Taking MNIST as an example:
import torch  # needed below for torch.utils.data.DataLoader
from torchvision import datasets
mnist = datasets.MNIST(root='./data/', train=True, download=True)
print(mnist[0])
Output:
(<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5)
Each sample in the MNIST dataset is a tuple made up of one image and one label.
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True, collate_fn=lambda x: x)
for each in dataloader:
    print(each)
    break
Output:
[(<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105630>, 0), (<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105668>, 2)]
When collate_fn is lambda x: x, the incoming data is not processed at all: each batch stays a list of (image, label) sample tuples.
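To make the role of collate_fn concrete, here is a minimal pure-Python sketch of the batching step. It is a simplified model, not PyTorch's actual implementation: `iter_batches` and `toy_dataset` are hypothetical names introduced for illustration, and shuffling, workers, and samplers are left out.

```python
def iter_batches(dataset, batch_size, collate_fn):
    """Group consecutive samples into lists and pass each list to collate_fn."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        samples = [dataset[i] for i in range(start, stop)]
        # collate_fn receives a list of samples and decides the batch layout
        yield collate_fn(samples)


# Hypothetical stand-in for MNIST: strings in place of PIL images
toy_dataset = [("img0", 5), ("img1", 0), ("img2", 4)]

# With the identity function, each batch is just the list of sample tuples
identity_batches = list(iter_batches(toy_dataset, 2, lambda x: x))
print(identity_batches)
```

Under this model, whatever collate_fn returns is exactly what the loop over the loader sees.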
Now let's define a custom collate_fn and see what effect it has.
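def collate(data):
    # data is a list of (image, label) tuples, one per sample in the batch
    img = []
    label = []
    for each in data:
        img.append(each[0])
        label.append(each[1])
    # return all images and all labels as two separate lists
    return img, label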
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True, collate_fn=collate)
for each in dataloader:
    print(each)
    break
Output:
([<PIL.Image.Image image mode=L size=28x28 at 0x241433A36D8>, <PIL.Image.Image image mode=L size=28x28 at 0x241433A3710>], [9, 3])
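The same regrouping, from a list of (image, label) pairs into a pair of lists, can also be sketched with `zip`. This is only an alternative way to write the idea; `collate_zip` is a hypothetical name, and strings stand in for the PIL images.

```python
def collate_zip(data):
    """Transpose a list of (image, label) tuples into (images, labels)."""
    # zip(*data) pairs up the first elements, then the second elements
    imgs, labels = zip(*data)
    return list(imgs), list(labels)


# Hypothetical batch of two samples
batch = collate_zip([("img_a", 9), ("img_b", 3)])
print(batch)
```

It assumes the batch is non-empty and every sample has exactly two fields.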
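Note: if the collate_fn argument is not set, the default collate function is used.
In that case the samples must consist of data it knows how to batch, such as tensors (it also accepts NumPy arrays and Python numbers); otherwise it raises an error. The PIL images returned by the dataset above are not supported, so the default function fails on them.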
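For contrast, here is a small sketch of what the default collate function does when the samples are already tensors. It assumes a toy in-memory dataset built with `TensorDataset` rather than real MNIST data; the zero-filled `images` tensor is a stand-in for images that have been converted to tensors.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for four 28x28 grayscale images already converted to tensors
images = torch.zeros(4, 1, 28, 28)
labels = torch.tensor([5, 0, 4, 1])
toy_dataset = TensorDataset(images, labels)

# No collate_fn given, so the default collate function is used
loader = DataLoader(toy_dataset, batch_size=2, shuffle=False)
batch_images, batch_labels = next(iter(loader))

# The samples are stacked along a new leading batch dimension
print(batch_images.shape, batch_labels)
```

Here the batch comes back as stacked tensors instead of Python lists, which is usually what a model expects as input.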