【ChatBot开发笔记】语料处理——数据整形

def collate_fn(batch):
    """
    计算该batch中的所有sample的最长的input,并且通过末尾补0将其他input的长度向其对齐
    """
    global pad_id
    input_ids = []
    btc_size = len(batch)
    # 该batch中最长的input,用于该batch的数据对齐
    max_input_len = 0  
    # 计算该batch中input的最大长度
    for btc_idx in range(btc_size):
        if max_input_len < len(batch[btc_idx]):
            max_input_len = len(batch[btc_idx])
    # 使用pad_id对小于max_input_len的input_id进行补全
    for btc_idx in range(btc_size):
        input_len = len(batch[btc_idx])
        input_ids.append(batch[btc_idx])
        input_ids[btc_idx].extend([pad_id] * (max_input_len - input_len))
    return torch.tensor(input_ids, dtype=torch.long)

该方法以batch为单位,以banch内最长的输入为标准,在对话尾部补充上pad_id=0,将batch内的数据长度调整一致

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值