def collate_fn(batch):
"""
计算该batch中的所有sample的最长的input,并且通过末尾补0将其他input的长度向其对齐
"""
global pad_id
input_ids = []
btc_size = len(batch)
# 该batch中最长的input,用于该batch的数据对齐
max_input_len = 0
# 计算该batch中input的最大长度
for btc_idx in range(btc_size):
if max_input_len < len(batch[btc_idx]):
max_input_len = len(batch[btc_idx])
# 使用pad_id对小于max_input_len的input_id进行补全
for btc_idx in range(btc_size):
input_len = len(batch[btc_idx])
input_ids.append(batch[btc_idx])
input_ids[btc_idx].extend([pad_id] * (max_input_len - input_len))
return torch.tensor(input_ids, dtype=torch.long)
该方法以batch为单位,以banch内最长的输入为标准,在对话尾部补充上pad_id=0,将batch内的数据长度调整一致