首先进行导入 from torchvision import datasets
1.datasets.ImageFolder()是一个通用的数据加载器
dataset=torchvision.datasets.ImageFolder(
root, transform=None,
target_transform=None,
loader=<function default_loader>,
is_valid_file=None)
参数:
- root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
- transform:对图片进行预处理的操作函数,原始图片作为输入,返回一个转换后的图片。
- target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。 如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
- loader:表示数据集加载方式,通常默认加载方式即可。
- is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)
返回的dataset有三个属性:
- dataset.classes:输出为类别列表
- dataset.class_to_idx:输出为{类别名:索引}的字典
- dataset.imgs:输出为[(图片路径,索引),...]的列表
from torchvision.datasets import datasets
from torchvision import transforms
data_transform = {
"train": transforms.Compose([transforms.RandomVerticalFlip(),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(degrees=(0, 180)),
transforms.ColorJitter(brightness=(0.9, 1.2),
contrast=(0.9, 1.2),
saturation=(0.8, 1.3)),
transforms.ToTensor()}
dataset=datasets.ImageFolder(root=args.train_path, transform=data_transform["train"])
那么得到的dataset,它的结构就是[(img_data,class_id),(img_data,class_id),…]
print(dataset[0])
# 输出:
(tensor([[[-0.5137, -0.4667, -0.4902, ..., -0.0980, -0.0980, -0.0902],
[-0.5922, -0.5529, -0.5059, ..., -0.0902, -0.0980, -0.0667],
[-0.5373, -0.5294, -0.4824, ..., -0.0588, -0.0824, -0.0196],
...,
[-0.3098, -0.3882, -0.3725, ..., -0.4353, -0.4510, -0.4196],
[-0.2863, -0.3647, -0.3725, ..., -0.4431, -0.4118, -0.4196],
[-0.3412, -0.3569, -0.3882, ..., -0.4667, -0.4588, -0.4196]],
[[-0.6627, -0.6157, -0.6549, ..., -0.5059, -0.5216, -0.5137],
[-0.7412, -0.7020, -0.6706, ..., -0.4980, -0.5216, -0.4902],
[-0.6863, -0.6784, -0.6471, ..., -0.4667, -0.4902, -0.4275],
...,
[-0.6000, -0.6549, -0.6627, ..., -0.6784, -0.6941, -0.6627],
[-0.5765, -0.6314, -0.6471, ..., -0.6863, -0.6549, -0.6627],
[-0.6314, -0.6314, -0.6392, ..., -0.7098, -0.7020, -0.6627]]]), 0)