例子
例如:自定义的dataset返回的单个样本格式为:(string, tensor)
直接用dataloader(dataset)得到的loader是不能够自动把上述格式转换为batch的。
解决方法:
需要自定义一个collate_function用于返回batch。
def collate_function(data):
"""
:data: a list for a batch of samples. [[string, tensor], ..., [string, tensor]]
"""
transposed_data = list(zip(*data))
directorys, imgs = transposed_data[0], transposed_data[1]
imgs = torch.stack(imgs, 0)
return (directorys, imgs)
dataloader = torch.utils.data.DataLoader(Dataset(transforms=data_transforms, train=False),
batch_size=2, collate_fn=collate_function, shuffle=True, num_workers=1, pin_memory=True)