Pytorch如何自定义dataloader的返回格式

例子

例如:自定义的dataset返回的单个样本格式为:(string, tensor)
直接用dataloader(dataset)得到的loader是不能够自动把上述格式转换为batch的。

解决方法:

需要自定义一个collate_function用于返回batch。

def collate_function(data):
	"""
	:data: a list for a batch of samples. [[string, tensor], ..., [string, tensor]]
	"""
    transposed_data = list(zip(*data))
    directorys, imgs = transposed_data[0], transposed_data[1]
    imgs = torch.stack(imgs, 0)
    return (directorys, imgs)

dataloader = torch.utils.data.DataLoader(Dataset(transforms=data_transforms, train=False),
                                         batch_size=2, collate_fn=collate_function, shuffle=True, num_workers=1, pin_memory=True)

参考:

Dataloader的官网源码
PytorchDiscuss

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值