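In experiments, image data is often stored by class: each folder holds the images of one label. In the cats-vs-dogs dataset, for example, cat images go in a folder named cat and dog images go in a folder named dog. When training a model, the data usually has to be split into training, validation, and test sets, and images of the same class must keep the same label in every split.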
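As a rough illustration of the idea, here is a minimal sketch that uses only the standard library. It assumes one sub-folder per class, assigns label indices from the sorted folder names (the same convention `ImageFolder` follows), and splits each class separately so that every split contains every class. The helper names `list_images_by_class` and `stratified_split`, the extension list, and the default ratios are all made up for this example and are not part of the original script.

```python
import os
import random
from collections import defaultdict

# Hypothetical list of file extensions treated as images in this sketch.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def list_images_by_class(root):
    """Scan root/<class_name>/* and return (class_to_idx, samples).

    samples is a list of (file_path, label_index) pairs. Label indices
    come from the sorted folder names, so a class always maps to the
    same index no matter which split its images end up in.
    """
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith(IMG_EXTENSIONS):
                samples.append((os.path.join(folder, fname), class_to_idx[name]))
    return class_to_idx, samples


def stratified_split(samples, val_ratio=0.1, test_ratio=0.1, seed=0):
    """Split (path, label) pairs into train/val/test lists, class by class."""
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    by_label = defaultdict(list)
    for path, label in samples:
        by_label[label].append((path, label))

    train, val, test = [], [], []
    for label in sorted(by_label):
        items = by_label[label]
        rng.shuffle(items)
        n_val = int(len(items) * val_ratio)
        n_test = int(len(items) * test_ratio)
        val.extend(items[:n_val])
        test.extend(items[n_val:n_val + n_test])
        train.extend(items[n_val + n_test:])
    return train, val, test
```

Each resulting list can be unzipped into file names and labels and passed to a custom `Dataset`. Splitting per class rather than over the whole file list keeps the class proportions roughly equal across the three sets, which matters when one class has far fewer images than another.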
# -*- coding: utf-8 -*-
from torchvision.datasets import ImageFolder
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
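# Channel-wise mean and std of ImageNet, applied after ToTensor()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: random crop and horizontal flip for augmentation
train_transformer_ImageNet = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

# Validation pipeline: deterministic resize and center crop, no augmentation
val_transformer_ImageNet = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])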
class MyDataset(Dataset):
    def __init__(self, filenames, labels, transform):
        self.filenames = filenames
        self.labels