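After two days of digging I finally got the code running. In short, for a typical dataset (by "typical" I mean the common CV tasks, though I'm not an expert on the specifics), you don't need to specify the collate_fn parameter when using torch.utils.data.DataLoader; PyTorch calls default_collate() directly. In my task, however, the tensors in a batch differ along several dimensions, so I had to write my own collate_fn() that brings every tensor in the batch to the same shape before they can be processed in parallel.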
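To make the failure mode concrete, here is a minimal sketch of what default_collate() does with same-shaped and differently-shaped samples. The tensors are made-up stand-ins for videos with different frame counts, not my real data, and the import path is the one I believe works across PyTorch versions.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Hypothetical "videos" shaped [fms, C, H, W] with different frame counts
ragged = [torch.zeros(3, 3, 8, 8), torch.zeros(5, 3, 8, 8)]

try:
    default_collate(ragged)
    stacked_ok = True
except RuntimeError:
    # default_collate stacks tensors, and stacking needs identical shapes
    stacked_ok = False

# With identical shapes the samples are stacked along a new batch dimension
uniform = [torch.zeros(5, 3, 8, 8), torch.zeros(5, 3, 8, 8)]
batch = default_collate(uniform)

print(stacked_ok, tuple(batch.shape))
```

So as long as every sample has the same shape the default is enough; once the shapes differ, something has to pad or crop them first, and collate_fn is the hook for that.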
Here is my code as a concrete example:
import torch
import torchvision


def my_collate_fn(batch):
    # input: a list of bsz dicts, one per sample;
    # the structure of each dict is:
    # 'frames_label': a tensor and a string:
    #                 [fms, C, H, W] ('AO8RW',)
    # 'query_sent': ['a person is putting a book on a shelf.']
    # 'start_frame': tensor([1])
    # 'end_frame': tensor([809])
    # 'clip_start_frame': tensor([0])
    # 'clip_end_frame': tensor([165])
    # Use the actual list length: the last batch may be smaller than config.bsz
    bsz = len(batch)
    fms_list = [e['frames_label'][0].shape[0] for e in batch]
    max_fms = max(fms_list)
    crop = torchvision.transforms.CenterCrop([224, 224])
    batch_tensor = {}
    # zeros rather than empty, so frames past a video's length are zero padding
    batch_tensor['frames'] = torch.zeros(bsz, max_fms, 3, 224, 224)
    batch_tensor['video_name'] = []
    batch_tensor['query_sent'] = []
    batch_tensor['start_frame'] = torch.empty(bsz)
    batch_tensor['end_frame'] = torch.empty(bsz)
    batch_tensor['clip_start_frame'] = torch.empty(bsz)
    batch_tensor['clip_end_frame'] = torch.empty(bsz)
    for i, video in enumerate(batch):
        frames = video['frames_label'][0]
        # CenterCrop accepts [..., H, W] tensors, so one call crops every frame
        batch_tensor['frames'][i, :fms_list[i]] = crop(frames)
        batch_tensor['video_name'].append(video['frames_label'][1])
        batch_tensor['query_sent'].append(video['query_sent'])
        batch_tensor['start_frame'][i] = video['start_frame']
        batch_tensor['end_frame'][i] = video['end_frame']
        batch_tensor['clip_start_frame'][i] = video['clip_start_frame']
        batch_tensor['clip_end_frame'][i] = video['clip_end_frame']
    return batch_tensor
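When you specify your own collate_fn (named my_collate_fn here), the DataLoader passes it the batch_size objects sampled from the dataset as a list. The function then handles the dimensions and keys as the task requires and returns a batch. The batch I get has the following structure: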
#batch structure:
#'frames': bsz, fms, C, H, W
#'video_name': list of bsz
#'query_sent': list of bsz
#'start_frame': bsz
#'end_frame': bsz
#'clip_start_frame': bsz
#'clip_end_frame': bsz
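To show the hand-off from DataLoader to collate_fn end to end, here is a small self-contained sketch. ToyVideoDataset and pad_collate are hypothetical names made up for this illustration, and the tiny 4x4 "frames" are only there so it runs quickly; it is a simplified version of the same padding idea, not my actual pipeline.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyVideoDataset(Dataset):
    """Hypothetical dataset: sample idx is a clip with idx + 1 frames."""

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        return {'frames': torch.ones(idx + 1, 3, 4, 4),
                'video_name': f'clip{idx}'}


def pad_collate(batch):
    # batch is a plain Python list holding what __getitem__ returned
    max_fms = max(s['frames'].shape[0] for s in batch)
    frames = torch.zeros(len(batch), max_fms, 3, 4, 4)
    for i, s in enumerate(batch):
        # copy the real frames; the remaining slots stay zero padding
        frames[i, :s['frames'].shape[0]] = s['frames']
    return {'frames': frames,
            'video_name': [s['video_name'] for s in batch]}


loader = DataLoader(ToyVideoDataset(), batch_size=2, collate_fn=pad_collate)
shapes = [tuple(b['frames'].shape) for b in loader]
print(shapes)
```

With five samples and batch_size=2 the last batch contains a single sample, and each batch is padded only up to its own longest clip. That is why the sketch sizes the output with len(batch) instead of a fixed batch size.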
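Of course, I have only solved the dimension and key problem here; collate_fn can certainly do much more than that. I hope everyone keeps exploring it so we can all improve together~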
Finally, here is an animation I came across today about how torch.utils.data.DataLoader works. It's really great!