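With torchvision, the computer-vision package that ships alongside PyTorch, building a dataset and a DataLoader is fairly straightforward.
First, make sure your images are organized like this, with one subfolder per class: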
dataset/classA/1.png
dataset/classA/2.png
dataset/classB/1.png
dataset/classB/2.png
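ImageFolder uses each subfolder name as a class: the class names are sorted by name and numbered from 0, so with the layout above classA gets label 0 and classB gets label 1 (the real mapping is exposed as `dataset.classes` and `dataset.class_to_idx`). The stdlib-only sketch below mimics that convention on a throwaway directory so you can see the resulting (path, label) pairs. It is not torchvision's code: `build_samples` and the short extension set are illustrative assumptions, and unlike ImageFolder it only looks one level deep inside each class folder.

```python
import os
import tempfile

# Hypothetical subset of accepted image extensions (torchvision's own list is longer)
IMAGE_EXTS = {".png", ".jpg", ".jpeg"}

def build_samples(root):
    """Mimic ImageFolder's convention: sorted subfolder names -> labels 0, 1, 2, ..."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        # Non-recursive simplification: only files directly inside the class folder
        for fname in sorted(os.listdir(class_dir)):
            if os.path.splitext(fname)[1].lower() in IMAGE_EXTS:
                # Paths kept relative to root for readability (ImageFolder stores full paths)
                samples.append((os.path.join(name, fname), class_to_idx[name]))
    return class_to_idx, samples

# Recreate the layout above with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    for cls in ("classA", "classB"):
        os.makedirs(os.path.join(root, cls))
        for fname in ("1.png", "2.png"):
            open(os.path.join(root, cls, fname), "w").close()
    class_to_idx, samples = build_samples(root)

print(class_to_idx)  # {'classA': 0, 'classB': 1}
print(samples)
```

Because labels come from sorted folder names, renaming or adding a class folder can shift every label, so it is worth printing `dataset.class_to_idx` once and keeping it with your trained model.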
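import torch
from torchvision import transforms, datasets, utils

# Random crop resized to 576x576, random horizontal flip, then PIL image -> float tensor in [0, 1]
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(576),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Optional ImageNet normalization; if enabled, undo it before displaying images with imshow
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                      std=[0.229, 0.224, 0.225])
])

# root points at the dataset/ folder shown above (the path is relative to the working directory)
dataset = datasets.ImageFolder(root='../dataset/', transform=data_transform)
# Yield shuffled batches of 4 samples, loaded by 4 worker processes
# (on Windows, run this under `if __name__ == '__main__':` when num_workers > 0)
dataset_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)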
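With `shuffle=True` and `batch_size=4`, the DataLoader draws a new random order of sample indices each time you iterate over it (each epoch), splits that order into chunks of 4, and stacks the 4 transformed images into one tensor of shape [4, 3, 576, 576] plus a label tensor of shape [4]. Because `drop_last` defaults to False, the final batch can be smaller. The pure-Python sketch below imitates only the index logic (no worker processes, no tensor stacking); `make_batches` is a hypothetical helper, not part of PyTorch.

```python
import random

def make_batches(num_samples, batch_size, shuffle=True, drop_last=False, seed=None):
    """Hypothetical stand-in for DataLoader's sampler + batching of indices."""
    indices = list(range(num_samples))
    if shuffle:
        # DataLoader draws a fresh permutation like this at the start of each epoch
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete final batch
    return batches

# 10 images with batch_size=4 -> batches of 4, 4 and 2 indices
batches = make_batches(10, batch_size=4, seed=0)
print([len(b) for b in batches])  # [4, 4, 2]
```

Every index appears exactly once per epoch; only the order changes between epochs.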
Then loop over the loader to inspect each batch:
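import matplotlib.pyplot as plt

for images, labels in dataset_loader:
    # images: [batch, 3, 576, 576] (batch is 4, except possibly the last one); labels: [batch]
    print(labels)
    # Tile the batch into a single image tensor of shape [3, H, W]
    grid = utils.make_grid(images)
    # matplotlib expects [H, W, C], so move the channel axis last
    plt.imshow(grid.numpy().transpose(1, 2, 0))
    plt.show()  # close the window to move on to the next batch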
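If you enable the commented-out `Normalize` step, each channel value becomes (x - mean) / std, so values leave the [0, 1] range and `plt.imshow` clips them, which makes the grid look wrong. Before displaying, invert the transform per channel with x = x_norm * std + mean. The stdlib sketch below shows that arithmetic on a single RGB pixel; `normalize` and `denormalize` are illustrative helpers, and in real code you would apply the same formula to the whole tensor (for example with mean and std reshaped to [3, 1, 1] so they broadcast), which is an assumption about how you would wire it in.

```python
# ImageNet statistics from the commented-out Normalize call
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize(pixel):
    """What transforms.Normalize does to one pixel: (x - mean) / std per channel."""
    return [(x - m) / s for x, m, s in zip(pixel, MEAN, STD)]

def denormalize(pixel):
    """Inverse transform to apply before plt.imshow: x * std + mean per channel."""
    return [x * s + m for x, m, s in zip(pixel, MEAN, STD)]

pixel = [1.0, 0.5, 0.0]        # an RGB value after ToTensor, in [0, 1]
norm = normalize(pixel)        # roughly [2.249, 0.196, -1.804]: outside [0, 1]
restored = denormalize(norm)   # back to approximately [1.0, 0.5, 0.0]
print(norm, restored)
```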
Each iteration displays the images of the current batch tiled into one grid.
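With its defaults (`nrow=8`, `padding=2`), `make_grid` puts up to 8 images per row with a 2-pixel gap around each, so a batch of 4 images of 576×576 becomes a single row of height 576 + 2·2 = 580 and width 4·(576 + 2) + 2 = 2314. The sketch below reproduces that size arithmetic, assuming 3-channel input and a batch of more than one image (make_grid returns a lone image unchanged and expands single-channel input to 3 channels); `grid_shape` is a hypothetical helper, so check the result against your torchvision version.

```python
import math

def grid_shape(batch, channels, height, width, nrow=8, padding=2):
    """Hypothetical helper: output shape of make_grid for batch > 1 equal-size 3-channel images."""
    cols = min(nrow, batch)             # images per row
    rows = math.ceil(batch / cols)      # number of rows
    cell_h, cell_w = height + padding, width + padding
    # One extra padding strip closes the bottom and right edges
    return (channels, rows * cell_h + padding, cols * cell_w + padding)

print(grid_shape(4, 3, 576, 576))   # (3, 580, 2314)
print(grid_shape(10, 3, 576, 576))  # (3, 1158, 4626)
```

This explains why a smaller final batch shows up as a narrower grid than the others.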