对于分类存储的图片,pytorch可以用ImageFolder直接读取,非常方便,但是当需要把训练集划分为训练加验证的话,这个就不太能胜任了。
参考将分类存储的图片切分为训练集、验证集和测试集(PyTorch实现),可以把数据集划分为训练集和数据集,根据自己的数据集和需求小改了一下代码。
原文是针对所有类别样本数目都一样写的,我改成了当每个类别样本数目不一样的时候怎么按比例划分。
from torchvision.datasets import ImageFolder
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transformer_ImageNet = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
val_transformer_ImageNet = transforms.Compose([
transforms.Resize(224),
tran