刚开始学习PyTorch,找了很多自定义数据加载的方法,还是使用torch中封装的库函数好用,而且快捷,会根据路径自动返回对应的标签,下面的代码每一行都给了注释。
import torch
from torchvision import transforms, utils
from torchvision import datasets
import torch.utils.data
import matplotlib.pyplot as plt
# 定义图像预处理
transform1 = transforms.Compose([ # 这里最好加上一个中括号,否则会被认为是意外实参
transforms.RandomHorizontalFlip(p=0.3), # 随机水平翻转,概率为0.3
transforms.RandomVerticalFlip(p=0.3), # 随机垂直翻转,概率为0.3
transforms.Resize((32, 32)), # 转换成32*32类型,便于收敛,此外大多数论文中也采用了这种大小,便于借鉴
transforms.ToTensor(), # 转换成Tensor类型
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # 进行归一化处理,可以直接套用
])
train_data = datasets.ImageFolder(r"C:\Users\asus\Desktop\cnn_data\cnn_data\data\training_data", transform=transform1)
print(train_data.classes) # 返回训练集的标签
# 使