The following complete program loads a dataset successfully using torchvision's ImageFolder class:
import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms,utils
import numpy as np
# ImageFolder expects the dataset to be organized as follows (one subfolder per class):
'''
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
'''
img_data = torchvision.datasets.ImageFolder(
    root=r'E:\机器学习数据集\flower_photos',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
)
print('Dataset classes:', img_data.classes)
print('Dataset size:', len(img_data))
# Wrap the dataset in torch.utils.data.DataLoader to get a DataLoader instance
data_loader = torch.utils.data.DataLoader(img_data,batch_size=36, shuffle=True)
print(len(data_loader))
def imshow(img):
    # img = img / 2 + 0.5 # unnormalize
    img = torchvision.utils.make_grid(img, nrow=6)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title('Batch from dataloader')
    plt.xticks([])
    plt.yticks([])
    plt.show()
# get some random training images
dataiter = iter(data_loader)
images, labels = next(dataiter)  # the built-in next() works across PyTorch versions
print(images.shape, labels)
# show images
imshow(images)
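The program above applies three transforms: Resize, CenterCrop, and ToTensor. This is where the problems show up.
- Problem 1: if CenterCrop is removed, leaving only Resize and ToTensor, the program fails with: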
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 408 and 341 in dimension 3 at ..\aten\src\TH/generic/THTensor.cpp:711
- Problem 2: if ToTensor is removed, or moved earlier so that it is no longer the last transform, the program fails with:
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.Image.Image'>
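A likely explanation for the first problem: when Resize is given a single integer, it scales the shorter edge to that value and keeps the aspect ratio, so images with different proportions end up with different widths and cannot be stacked into one batch. The sketch below mimics that size arithmetic in plain Python. The helper `resized_size` and the two source image sizes (500x313 and 320x240) are assumptions chosen for illustration; they happen to reproduce the widths 408 and 341 from the error message, but the actual files in the dataset were not inspected.

```python
def resized_size(width, height, size=256):
    """Approximate the output size of Resize(size) for an integer size.

    The shorter edge becomes `size`; the longer edge is scaled by the
    same factor and truncated to an integer. This is a hypothetical
    helper written for illustration, not a torchvision function.
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


# Hypothetical source image sizes given as (width, height).
for w, h in [(500, 313), (320, 240)]:
    print((w, h), '->', resized_size(w, h))
# (500, 313) -> (408, 256)
# (320, 240) -> (341, 256)
```

Both results share the height 256 but differ in width, which would match a mismatch reported in dimension 3 of an N x C x H x W batch. Keeping CenterCrop(224), or passing a tuple such as Resize((224, 224)), gives every image the same size.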
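Other combinations did not seem to cause problems. I am leaving these two questions open for now and will answer them once I have studied this in more depth.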