part 1:
利用ImageFolder读入数据,可以不重写dataloader,直接写在train.py。
from torchvision.datasets import ImageFolder
这篇博客写的很好
https://blog.csdn.net/weixin_40123108/article/details/85099449?spm=1001.2014.3001.5506
torchvision已经预先加载了常用的Dataset,包括前面使用过的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集,可通过诸如torchvision.datasets.CIFAR10来调用。
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
root:在root指定的路径下寻找图片
transform:对 Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
target_transform:对label转换
loader:给定路径后如何读取图片,默认读取为RGB格式的 Image对象
label是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一致,如果不是这种命名规范,建议看看self.class_to_idx属性以了解label和文件夹名的映射关系。
可以如下直接写在train
transform = transforms.Compose([
# you can add other transformations in this list
transforms.RandomResizedCrop(124),
transforms.RandomHorizontalFlip(),
transforms.Grayscale(num_output_channels=1), #灰度转化
transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225]),
])
dataset_train= ImageFolder('./img1/Train',transform=transform)
print(dataset_train.imgs) #输出类名
print(dataset_train.class_to_idx) #输出文件路径和类别
dataset_valid= ImageFolder("./img1/Valid",transform=transform)
dataset_test= ImageFolder("./img1/Test",transform=transform)
print(dataset_test.class_to_idx)
print(dataset_train[0][0])
需要的文件格式如下:
输出为:
part 2:重写dataloader