在使用pytorch做深度学习任务的数据加载时,常用的方式是使用torchvision.Dataset
类定义数据读取,然后使用torch.utils.data.DataLoader
定义数据加载器。该部分内容见Pytorch学习(一)
不过,有些分类数据的文件目录组织形式如下:
即默认你的数据集已经自觉按照要分配的类型分成了不同的文件夹,一种类型的文件夹下面只存放一种类型的图片。
这时候,定义数据读取时,使用 torchvision
包中的ImageFolder
类会比Dataset
类会更方便。
ImageFolder
CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
一个通用数据加载器,其中图像以这种方式排列:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
参数:
root (string) – 指定图片存储的路径
transform (callable, optional) – 一个transform函数,接受PIL.Image图像并返回一个转换后的图片
target_transform (callable, optional) –一个函数,输入为target,输出对其的转换。
loader (callable, optional) – A function to load an image given its path.
is_valid_file – 该函数获取图像文件的路径并检查该文件是否为有效文件
成员变量:
-
self.classes - 用一个list保存 类名
-
self.class_to_idx - 类名对应的 索引
-
self.imgs - 保存(img-path, class) tuple的list。
return self.imgs
即相当于Dataset
类中的return (img, target)
例子:
# 指定读取的图片路径
train_root = './train/
# transform函数组合
train_transform = transforms.Compose([
transforms.Resize(224),
transforms.RandomResizedCrop(224,scale=(0.6,1.0),ratio=(0.8,1.0)),
transforms.RandomHorizontalFlip(),
torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
transforms.ToTensor(),
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
# 使用ImageFolder读取数据
all_data = torchvision.datasets.ImageFolder(
root=train_root,
transform=train_transform
)
# 定义数据加载器
train_set = torch.utils.data.DataLoader(
all_data,
batch_size=BTACH_SIZE,
shuffle=True
)
参考
https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder
https://www.jb51.net/article/180916.htm