Using the collate_fn function in torch.utils.data.DataLoader

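After two days of digging I finally got the code running. In short, for an ordinary dataset (here I mean the common CV tasks, though I don't know the details very well), you don't need to set this parameter of torch.utils.data.DataLoader, because PyTorch simply calls default_collate(). default_collate() stacks the samples with torch.stack, which requires every sample to have the same shape. In my task, however, the tensors in a batch differ along several dimensions, so I had to write my own collate_fn() that brings every tensor in a batch to the same shape; only then can the batch be processed in parallel.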
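To see the difference, here is a minimal sketch with made-up variable-length tensors (the shapes and the name pad_collate are assumptions for illustration, not from my task). The default collation fails with a RuntimeError, while a custom collate_fn that zero-pads the samples with torch.nn.utils.rnn.pad_sequence produces one batch tensor plus the true lengths.

```python
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# Hypothetical samples: sequences of different lengths, each with feature size 4
samples = [torch.randn(n, 4) for n in (5, 3, 7)]

def pad_collate(batch):
    # batch is a plain list of samples; pad along dim 0 to the longest one
    lengths = torch.tensor([s.shape[0] for s in batch])
    padded = pad_sequence(batch, batch_first=True)  # [bsz, max_len, 4], zero-padded
    return padded, lengths

# Default collation: torch.stack fails because the first dimension differs
try:
    next(iter(DataLoader(samples, batch_size=3)))
    default_failed = False
except RuntimeError:
    default_failed = True

padded, lengths = next(iter(DataLoader(samples, batch_size=3, collate_fn=pad_collate)))
print(default_failed, padded.shape, lengths.tolist())
# True torch.Size([3, 7, 4]) [5, 3, 7]
```

Returning the lengths alongside the padded tensor is optional, but it lets later code tell real entries from padding.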

Take my code as an example:

import torch
import torchvision

def my_collate_fn(batch):
    # input: a list of bsz dicts (one per sample); each one looks like:
    # 'frames_label': a tensor and a string:
    #        [fms, C, H, W]   ('AO8RW',)
    # 'query_sent': ['a person is putting a book on a shelf.']
    # 'start_frame': tensor([1])
    # 'end_frame': tensor([809])
    # 'clip_start_frame': tensor([0])
    # 'clip_end_frame': tensor([165])
    index_keys = ['start_frame', 'end_frame', 'clip_start_frame', 'clip_end_frame']

    # use the real batch size: the last batch can be smaller than config.bsz
    bsz = len(batch)
    fms_list = [e['frames_label'][0].shape[0] for e in batch]
    max_fms = max(fms_list)

    crop = torchvision.transforms.CenterCrop([224, 224])

    batch_tensor = {}
    # zeros instead of empty, so the padded frames of shorter videos are all-zero
    batch_tensor['frames'] = torch.zeros(bsz, max_fms, 3, 224, 224)
    batch_tensor['video_name'] = []
    batch_tensor['query_sent'] = []
    for key in index_keys:
        # frame indices are integers, so store them as long
        batch_tensor[key] = torch.empty(bsz, dtype=torch.long)

    for i, video in enumerate(batch):
        # CenterCrop accepts [..., H, W] tensors, so all frames are cropped at once;
        # frames fms_list[i] .. max_fms-1 stay zero, which is the padding
        batch_tensor['frames'][i, :fms_list[i]] = crop(video['frames_label'][0])
        batch_tensor['video_name'].append(video['frames_label'][1])
        batch_tensor['query_sent'].append(video['query_sent'])
        for key in index_keys:
            batch_tensor[key][i] = int(video[key])  # tensor([1]) -> 1

    return batch_tensor
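The padding step can also be written with torch.nn.functional.pad. F.pad reads its pad tuple from the last dimension backwards, two entries (before, after) per dimension, so for a [fms, C, H, W] tensor the tuple (0, 0, 0, 0, 0, 0, 0, k) appends k zero frames. It returns a new tensor instead of padding in place, so the result has to be kept. A minimal sketch, where the helper name pad_frames and the small 8x8 frames are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def pad_frames(videos):
    # videos: list of [fms_i, C, H, W] tensors with equal C, H, W (hypothetical helper)
    max_fms = max(v.shape[0] for v in videos)
    padded = [
        # pad tuple runs from the last dim backwards: W, H, C, then fms (0 before, k after)
        F.pad(v, (0, 0, 0, 0, 0, 0, 0, max_fms - v.shape[0]))
        for v in videos
    ]
    return torch.stack(padded)  # [bsz, max_fms, C, H, W]

videos = [torch.ones(2, 3, 8, 8), torch.ones(5, 3, 8, 8)]
out = pad_frames(videos)
print(out.shape)  # torch.Size([2, 5, 3, 8, 8])
```

Padding every sample to max_fms first and then calling torch.stack gives the same zero-padded layout as filling a preallocated zeros tensor.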

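When you pass your own collate_fn (here named my_collate_fn), the DataLoader hands it the batch_size samples drawn from the dataset as a list. The function then processes the dimensions and keys as needed and returns a single batch. The batch I get has the following structure: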

#batch structure:
#'frames': bsz, fms, C, H, W
#'video_name': list of bsz
#'query_sent': list of bsz
#'start_frame': bsz
#'end_frame': bsz
#'clip_start_frame': bsz
#'clip_end_frame': bsz
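To check the whole flow end to end, here is a self-contained sketch with a hypothetical toy dataset (ToyVideoDataset, its frame counts, the 3x32x32 frames and toy_collate_fn are all assumptions standing in for real data). It shows that collate_fn receives a plain list of dicts, that the last batch can be smaller than batch_size, and that returning the true frame counts plus a mask makes the padding easy to ignore later.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyVideoDataset(Dataset):
    # Hypothetical dataset: sample i is a video with (i % 3) + 2 frames of size 3x32x32
    def __len__(self):
        return 5

    def __getitem__(self, i):
        fms = (i % 3) + 2
        return {
            'frames': torch.rand(fms, 3, 32, 32),
            'video_name': f'video{i}',
            'start_frame': torch.tensor([i]),
        }

def toy_collate_fn(batch):
    # batch is a list of dicts, one per sample
    bsz = len(batch)
    num_frames = torch.tensor([e['frames'].shape[0] for e in batch])
    max_fms = int(num_frames.max())
    frames = torch.zeros(bsz, max_fms, 3, 32, 32)
    for i, e in enumerate(batch):
        frames[i, :num_frames[i]] = e['frames']  # the rest stays zero (padding)
    return {
        'frames': frames,
        'num_frames': num_frames,
        # True for real frames, False for padded ones
        'mask': torch.arange(max_fms)[None, :] < num_frames[:, None],
        'video_name': [e['video_name'] for e in batch],
        'start_frame': torch.cat([e['start_frame'] for e in batch]),
    }

loader = DataLoader(ToyVideoDataset(), batch_size=2, collate_fn=toy_collate_fn)
batches = list(loader)
for b in batches:
    print(b['frames'].shape, b['num_frames'].tolist(), b['video_name'])
# torch.Size([2, 3, 3, 32, 32]) [2, 3] ['video0', 'video1']
# torch.Size([2, 4, 3, 32, 32]) [4, 2] ['video2', 'video3']
# torch.Size([1, 3, 3, 32, 32]) [3] ['video4']
```

With 5 samples and batch_size=2 the last batch holds a single video, which is why the batch size inside collate_fn should come from len(batch) rather than a fixed config value.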

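Of course, I only solved the dimension and key problems here; collate_fn can do much more than that. I hope we all keep exploring it and improve together~
Finally, here is an animation I came across today that shows how torch.utils.data.DataLoader works. It's really great!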

https://mp.weixin.qq.com/s/Uc2LYM6tIOY8KyxB7aQrOw