pytorch–DataLoader的collate_fn参数
'''
>>> a = [1, 2, 3]
>>> b = [4, 5, 6]
>>> c = [4, 5, 6, 7, 8]
>>> zipped = list(zip(a, b)) # 打包为元组的列表 (Python 3 中 zip 返回迭代器,需用 list() 转换)
>>> zipped
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(a, c)) # 元素个数与最短的列表一致
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(*zipped)) # 与 zip 相反,*zipped 可理解为解压,返回二维矩阵式
[(1, 2, 3), (4, 5, 6)]
'''
def seq_collate(data):
    """Merge a list of per-sequence samples into one batch (a ``collate_fn``).

    Args:
        data: A non-empty sequence of samples. Each sample is a 6-tuple
            ``(obs_seq, pred_seq, obs_seq_rel, pred_seq_rel,
            non_linear_ped, loss_mask)`` of tensors. The four ``*_seq``
            tensors are laid out as ``(batch, input_size, seq_len)``; all
            six are concatenated along dimension 0.
            (Presumably dimension 0 indexes the pedestrians of one scene --
            TODO confirm against the dataset class.)

    Returns:
        A 7-tuple ``(obs_traj, pred_traj, obs_traj_rel, pred_traj_rel,
        non_linear_ped, loss_mask, seq_start_end)``. The four trajectory
        tensors are permuted to ``(seq_len, batch, input_size)``.
        ``seq_start_end`` is an int64 tensor with one ``[start, end]`` row per
        sample, giving that sample's half-open slice along the concatenated
        dimension.
    """
    # Transpose the list of samples into one tuple per field.
    (obs_seq_list, pred_seq_list, obs_seq_rel_list, pred_seq_rel_list,
     non_linear_ped_list, loss_mask_list) = zip(*data)

    # Size of each sample along dim 0; the running sum yields the offsets of
    # every sample inside the concatenated batch.
    seq_sizes = [len(seq) for seq in obs_seq_list]
    cum_start_idx = [0] + np.cumsum(seq_sizes).tolist()
    seq_start_end = [
        [start, end]
        for start, end in zip(cum_start_idx, cum_start_idx[1:])
    ]

    # Data format: batch, input_size, seq_len
    # LSTM input format: seq_len, batch, input_size
    obs_traj = torch.cat(obs_seq_list, dim=0).permute(2, 0, 1)
    pred_traj = torch.cat(pred_seq_list, dim=0).permute(2, 0, 1)
    obs_traj_rel = torch.cat(obs_seq_rel_list, dim=0).permute(2, 0, 1)
    pred_traj_rel = torch.cat(pred_seq_rel_list, dim=0).permute(2, 0, 1)

    non_linear_ped = torch.cat(non_linear_ped_list)
    loss_mask = torch.cat(loss_mask_list, dim=0)
    seq_start_end = torch.tensor(seq_start_end, dtype=torch.long)

    return (
        obs_traj, pred_traj, obs_traj_rel, pred_traj_rel, non_linear_ped,
        loss_mask, seq_start_end,
    )