torchvision.datasets.ImageFolder(root, transform, target_transform, loader)
参数:
- root:图片存储的根目录,即各类别文件夹所在目录的上一级目录,在下面的例子中是 “…/input/data/”
- transform:对图片进行预处理操作(函数),原始图片作为输入,返回一个转换后的图片。
- target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
- loader:表示数据集加载方式,通常默认加载方式即可
另外,该 API 有以下成员变量:
- self.classes:用一个 list 保存类别名称
- self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
- self.imgs:保存(img-path, class) tuple的 list,与我们自定义 Dataset类的 def getitem(self, index): 返回值类似。注意看下面实例中 dataset.imgs 的返回值
举例:
数据存储结构如下
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
trans = transforms.Compose([transforms.RandomCrop(224), transforms.ToTensor()])
dataset = torchvision.datasets.ImageFolder('../input/data', transform=trans)
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs)
print('\n')
train_loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
for (img, label) in train_loader:
print(img.shape)
print(label)
break