transformer

自动问答 pipeline 可以根据给定的上下文回答问题

可以看到,pipeline 自动选择了在 SQuAD 数据集上训练好的 distilbert-base 模型来完成任务。这里的自动问答 pipeline 实际上是一个抽取式问答模型,即从给定的上下文中抽取答案,而不是生成答案。

根据形式的不同,自动问答 (QA) 系统可以分为三种:

  • 抽取式 QA (extractive QA):假设答案就包含在文档中,因此直接从文档中抽取答案;
  • 多选 QA (multiple-choice QA):从多个给定的选项中选择答案,相当于做阅读理解题;
  • 无约束 QA (free-form QA):直接生成答案文本,并且对答案文本格式没有任何限制

自回归 语言生成是基于如下假设: 一个文本序列的概率分布可以分解为每个词基于其上文的条件概率的乘积 

past_key_values 是一个包含模型在之前时间步计算出的键(key)和值(value)的列表,用于在生成新的标记时保持上下文状态。这通常在自回归生成文本时使用,例如在聊天机器人或文本补全任务中。

让我们逐步分析这段代码:

  1. past_length = past_key_values[0][0].shape[2] 这行代码计算了 past_key_values 中第一个键-值对的第三维度的长度。在 Transformer 模型中,这个长度代表了之前生成的标记的数量。

  2. inputs.position_ids += past_length 这行代码更新了 inputs.position_idsposition_ids 是一个与输入标记对应的位置 ID 列表,用于在模型中指示每个标记的位置。由于我们正在添加新的标记,我们需要更新位置 ID 以反映它们在扩展序列中的正确位置。

  3. attention_mask = inputs.attention_mask 这行代码获取了当前的注意力掩码。注意力掩码是一个布尔掩码,用于在模型中指示哪些位置可以被用来计算注意力权重。通常,它用于屏蔽序列中的填充(padding)标记,确保模型不会在计算注意力时考虑这些填充标记。

  4. attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) 这行代码扩展了注意力掩码,使其在左侧添加了 past_length 个新的 True 值。这是因为在生成新的标记时,我们希望模型能够注意到之前生成的所有标记。torch.cat 函数用于在第一个维度(dim=1)上连接两个张量。

  5. inputs['attention_mask'] = attention_mask 最后,更新后的注意力掩码被赋值回 inputs 字典中的 'attention_mask' 键,这样它就可以在模型计算时使用了。

总的来说,这段代码确保了在生成文本时,模型能够正确地处理之前生成的标记,并且能够使用这些标记来计算新的标记的注意力权重。这是实现高效和上下文相关的文本生成的重要步骤

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值