自动问答 pipeline 可以根据给定的上下文回答问题
可以看到,pipeline 自动选择了在 SQuAD 数据集上训练好的 distilbert-base 模型来完成任务。这里的自动问答 pipeline 实际上是一个抽取式问答模型,即从给定的上下文中抽取答案,而不是生成答案。
根据形式的不同,自动问答 (QA) 系统可以分为三种:
- 抽取式 QA (extractive QA):假设答案就包含在文档中,因此直接从文档中抽取答案;
- 多选 QA (multiple-choice QA):从多个给定的选项中选择答案,相当于做阅读理解题;
- 无约束 QA (free-form QA):直接生成答案文本,并且对答案文本格式没有任何限制
自回归 语言生成是基于如下假设: 一个文本序列的概率分布可以分解为每个词基于其上文的条件概率的乘积
past_key_values
是一个包含模型在之前时间步计算出的键(key)和值(value)的列表,用于在生成新的标记时保持上下文状态。这通常在自回归生成文本时使用,例如在聊天机器人或文本补全任务中。
让我们逐步分析这段代码:
-
past_length = past_key_values[0][0].shape[2]
这行代码计算了past_key_values
中第一个键-值对的第三维度的长度。在 Transformer 模型中,这个长度代表了之前生成的标记的数量。 -
inputs.position_ids += past_length
这行代码更新了inputs.position_ids
。position_ids
是一个与输入标记对应的位置 ID 列表,用于在模型中指示每个标记的位置。由于我们正在添加新的标记,我们需要更新位置 ID 以反映它们在扩展序列中的正确位置。 -
attention_mask = inputs.attention_mask
这行代码获取了当前的注意力掩码。注意力掩码是一个布尔掩码,用于在模型中指示哪些位置可以被用来计算注意力权重。通常,它用于屏蔽序列中的填充(padding)标记,确保模型不会在计算注意力时考虑这些填充标记。 -
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
这行代码扩展了注意力掩码,使其在左侧添加了past_length
个新的True
值。这是因为在生成新的标记时,我们希望模型能够注意到之前生成的所有标记。torch.cat
函数用于在第一个维度(dim=1)上连接两个张量。 -
inputs['attention_mask'] = attention_mask
最后,更新后的注意力掩码被赋值回inputs
字典中的'attention_mask'
键,这样它就可以在模型计算时使用了。
总的来说,这段代码确保了在生成文本时,模型能够正确地处理之前生成的标记,并且能够使用这些标记来计算新的标记的注意力权重。这是实现高效和上下文相关的文本生成的重要步骤