使用PyTorch实现)。
```python
import torch
from torchvision import transforms, datasets
# 设定数据路径
data_dir = './flower'
# 转换图片大小
data_transforms = {
'train': transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'test': transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
# 创建Dataset类型的训练集和测试集
image_datasets = {x: datasets.ImageFolder(
root=data_dir+'/'+x, transform=data_transforms[x])
for x in ['train', 'test']}
# 创建DataLoader函数,batch_size=32
dataloaders = {x: torch.utils.data.DataLoader(
image_datasets[x], batch_size=32, shuffle=True)
for x in ['train', 'test']}
# 查看第一个批次的图像及对应标签
inputs, classes = next(iter(dataloaders['train']))
print(inputs.shape, classes)
```
输出:
```
torch.Size([32, 3, 224, 224]) tensor([ 7, 96, 41, 57, 70, 98, 11, 48, 0, 4, 4, 4, 4, 4, 4, 4, 40, 85,
93, 49, 55, 4, 47, 30, 4, 4, 4, 4, 4, 4, 4, 4])
```
其中`inputs`是一个大小为`(32, 3, 224, 224)`的Tensor,表示有32张224x224的RGB图片。`classes`是一个大小为`(32,)`的Tensor,表示对应的32张图片的类别标签。