论文地址:https://arxiv.org/abs/2308.04711
git地址:
1 研究背景:
在有限计算资源的情况下如何使用较小的语言模型(LLM)来回答简单推理问题问题。
2 研究对象:
该研究的主要研究对象是较小的语言模型,其目标是评估在有限计算资源情况下,小模型如何通过推理生成答案。研究特别关注使用基于Wikipedia的密集检索系统和LLM生成的解释作为知识来源 (和常规RAG方案不同的是,这里用了LLM根据自身生成的结果也作为推理的依据, knowledge augment,RR Model除了相关性排序,也起到了一个过滤LLM生成的hallucination的作用)
3 研究贡献
-
提出了 Rationale Ranking (RR) 方法:
- 该方法能够根据相关性和真实性对生成的推理(rationale)和检索到的上下文进行打分,并使用这些分数来组合不同的知识来源。通过这种方式,能够筛选出最相关且最真实的内容,进而提高小模型的推理能力。
-
应用了 Retrieval-Augmented Training Datasets (RATD) 方法:
- 该方法通过检索增强的数据集来训练推理模型,使得模型能够在长文本中进行推理,即便这些文本可能包含部分相关的事实或噪声信息。这种训练策略提升了小模型在复杂推理任务中的表现。
-
证明了这些方法在小模型中的有效性:
- 研究表明,使用上述两种方法,较小的语言模型能够在多种数据集上显著超越强基线,尤其是在应对未见问题(unseen questions)的情况下。小模型在如 StrategyQA、CommonsenseQA、ARC-DA 和 IIRC 数据集上的表现大幅提升。
-
展示了小模型与大模型的比较优势:
- 研究证明,小模型在结合两种知识来源(生成的推理和检索到的上下文)后,能够在某些推理任务上与更大的模型(如 BLOOM 175B 和 StableVicuna 13B)相媲美,甚至在一些任务上表现更好。
-
揭示了不同知识来源的优势:
- 研究深入分析了 LLM 生成的推理和基于 Wikipedia 的多跳检索之间的相对优势。LLM 生成的推理在常识推理任务中更为强大,而多跳检索在多步事实推理问题中表现更好。结合两种来源能够显著提升整体表现。
4 研究方法:
4.1 数据集
-
StrategyQA (SQA):
- 特点:SQA 是一个涉及多跳推理的常识问答数据集,答案为“是/否”类型,平均需要 2.33 步推理来得出答案。
- 任务:该数据集要求模型通过多个步骤推导出最终的常识性判断,测试模型的推理链条构建能力。
-
CommonsenseQA (CSQA):
- 特点:这是一个多选的常识问答数据集,问题源自 ConceptNet,给出5个选项,其中多个可能是合理的。
- 任务:模型需要选择最佳答案,重点考察模型在常识推理中的表现。
-
ARC-DA(AI2 Reasoning Challenge - Direct Answer):
- 特点:这是 ARC 的一个子集,其中问题被重新措辞以适应开放域的上下文。问题大多来自科学领域,答案格式为直接回答,而不是选择题。
- 任务:评估模型在科学推理中的表现,尤其是对模型的推理能力和直接回答问题的能力进行测试。
-
IIRC (Incomplete Information Reading Comprehension):
- 特点:这是一个事实性问答数据集,每个问题最初提供一个解释段落,但需要额外的信息才能得出完整的证据,平均需要 1 到 4 次检索。
- 任务:测试模型在信息不完整的情况下如何检索额外信息,并结合已有信息回答问题。
-
Musique:
- 特点:这是一个多跳事实推理数据集,由现有数据集中的单跳问题组合而成,每个问题平均需要 4 步推理。
- 任务:考察模型在复杂的多跳推理问题上的表现。
这些数据集的共同特点是它们涵盖了多种类型的推理任务,包括常识推理、多跳推理、事实检索等。
4.2 模型选型:
4.2.1 Rationale Generation
1、使用的两个模型:
BLOOM: 176B,最大模型之一,并且在推理任务中可以提供非常强的生成能力
StableVicuna:13B小模型,基础模型llama + ChatGPT conversations -> Vicuna, Vicuna + supervised,RLHF -> StableVicuna
2、INT8 和 FP16 的使用:
INT8 版本:StableVicuna 使用了 INT8 精度(8-bit 整数矩阵乘法),这大大减少了模型的内存需求,使得该模型可以在更小的显存中运行(约 18GB),适合资源受限的环境
FP16 版本:相比之下,StableVicuna 的 FP16 精度(16-bit 浮点数)版本占用的显存约为 36GB,虽然占用更多内存,但推理速度更快。因此,研究中对比了两种不同的精度设置,权衡了资源限制与推理速度之间的平衡
3、生成推理的过程
研究采用了 贪婪解码(greedy decoding) 的方式来生成推理解释。使用链式思维(chain-of-thought, COT)提示生成推理过程,并让模型生成出推理解释,之后给出最终答案。生成的推理长度限制为最多 128 个 token
4、少量提示样本:
为了保持模型训练的一致性,使用了相同的少量提示样本(few-shot prompts)来生成推理解释。这使得 BLOOM 和 StableVicuna 的生成推理能够在相似的条件下进行比较
附录中提供了样例
4.2.2 Retrieval
-
多跳检索模型 (n-hop Retrieval Model):
- 文中提到,检索知识的来源是使用一个名为“Iterator”的多跳密集检索模型,该模型最多可以执行4次跳跃(即 n ≤ 4)。这种多跳检索模型能够处理复杂的问题,通过多次检索,从多个相关文档中提取信息以回答问题。
- 例如,模型首先从问题中检索出第一个相关文档(d0),然后使用这个文档和问题一起作为输入,再去检索第二个文档(d1),依次类推,直到检索到足够多的文档为止。
-
两阶段重排序系统:做了复杂的重排设计,处于对多跳检索的考虑, 第一阶段注重精度,确保每个检索到的片段都是高相关的,而第二阶段则关注全面性,确保这些片段在一起能形成一个完整的证据链
- 第一阶段:段落重排序器(Paragraph Reranker)对检索到的段落和句子进行评分,评分的依据是与当前问题的相关性。
- 第二阶段:证据集评分模型(Evidence Set Scoring Model)会对选中的句子进行进一步评分,以判断整个证据集是否足够支持回答问题。
-
检索内容的格式:
- Iterator模型生成的上下文内容是从Wikipedia段落中提取出来的段落片段,这些片段是由得分最高的句子组成,并添加了相邻的上下文句子(如前后的句子)。每个上下文内容以文档的标题开头,然后是所选段落的1到3个句子,多个段落片段组成一个512-token的序列
-
知识来源:
- 文中的检索系统使用的是2020年8月1日的英文Wikipedia数据库,包含约3500万段落。这些检索到的上下文会被用于训练和评估模型,帮助模型从大量文本中提取有用的知识。
4.2.3 Rationale Ranker
Rationale Ranker 模型的构建及其训练过程。这个模型的目的是根据给定问题对上下文进行打分,确保系统选择的上下文不仅相关,而且是真实和可靠的。通过生成正负样本对,并结合 LLM 生成的负例数据,Rationale Ranker 能有效识别虚假信息并为问题提供足够的证据支持
1. Rationale Ranker 模型的输入与输出
- 输入:Rationale Ranker 接收的问题和上下文对(〈q, c〉),其中
q
是问题,c
是候选的上下文片段(可能是从检索系统或 LLM 生成的推理中获取的)。 - 输出:模型为每对输入生成一个得分
s
。这个得分反映了上下文c
在回答问题q
时的相关性和真实性
2.训练目标与数据集构建
- Rationale Ranker 使用 二元交叉熵损失函数(binary cross-entropy objective)进行训练,模型的任务是判断上下文是否为 “真” 或 “假”。如果上下文
c
是真实的且能够完全回答问题q
,那么该对样本会被标记为 1.0,否则标记为 0.0 - 训练数据集包括了正样本和负样本:
- 正样本:这些是能为问题提供真实且相关答案的上下文片段,通常是从数据集中提取的“gold sentences”(即足够的证据支持回答问题的句子)。
- 负样本:负样本是指与问题无关、或不真实的上下文片段。这些负样本通过以下两种方式生成:
- LLM 生成:利用 LLM(如 BLOOM)生成错误或不相关的推理内容作为负样本。
- 合成生成:通过替换 gold sentence 中的某些句子,人工构造无关或虚假的上下文
3. 训练样本构造
- 为了确保模型在训练过程中能够有效学习正负样本之间的差异,每个批次中都会包含针对相同问题
q
的正负样本对(positive and negative c)。这种 共享标准化(shared normalization)策略有助于模型更好地学习如何区分相关与不相关的上下文 - 数据集的构造方法包括从多个公开数据集(如 HotpotQA、FEVER、ARC等)中提取的 gold 句子,结合 LLM 生成的负样本,形成一个丰富的训练集
4. 模型评价
- 在开发集上的评估显示,Rationale Ranker 在检测上下文是否为真实或虚假的准确率很高(达到92.3%),尤其是在区分正负上下文时表现优异。
- 模型还被用于检测 LLM 生成内容中的虚假信息,并与其他模型进行比较(如 GPT-3 和 GPT-4)。实验结果表明,Rationale Ranker 在检测虚假信息方面优于未经过强化学习训练的较大模型
5. 与其他检索系统的对比
Rationale Ranker 与第一阶段的 段落重排序器 类似,都是用于对上下文片段进行评分。但 Rationale Ranker 不仅考虑相关性,还专门训练来检测虚假或不可靠的内容。因此,它在发现上下文中的虚假信息方面表现得更好,而不仅仅是判断上下文的相关性。
4.2.4 Reasoning Models
-
RATD 模型:
- RATD 模型是本文的基线模型,最初是由 Hartill 等人(2023)从预训练的 BART 模型进一步训练得到的。该模型在多任务环境中训练,目标是增强其推理能力,特别是处理长文本中的复杂推理任务。
- RATD 模型训练的数据集分为两类:
-
RATD 数据集:用于训练模型在包含冗余或部分相关的上下文中进行推理的能力。这些上下文可能包含与问题部分相关或不相关的事实,模型需要从中找到有用的证据。Common 数据集:涵盖基础的问答任务,旨在提高模型的基本推理和回答问题的能力。
-
GR 模型:
- GR 模型与 RATD 模型类似,区别在于它使用了一个额外的训练集,称为 GR 数据集(Gold Rationale 数据集)。GR 数据集专注于推理形式的上下文,即通过 Rationale Generation(推理生成)生成的上下文。
- GR 模型被训练为能够处理来自 Rationale-style 上下文的信息。
-
GR+RATD 模型:
- 这是文中主推的模型,结合了 Common 数据集、GR 数据集和 RATD 数据集进行训练,因此具备处理两类上下文的能力(即包含 gold rationale 的推理上下文和检索到的复杂上下文)。
- GR+RATD 模型通过将推理和检索上下文结合,可以更好地处理涉及复杂推理的任务。
5 实验:
数据集和模型已经在4中介绍
Context Combination
LLM生成和检索的文本如何拼接,设计了4种策略
Naïve Concatenation:无视RR score, 直接拼接chunk
Max Score:只取RR score高的chunk
RationaleDefault: 设置阈值比如0.75,如何检索结果RR score高于0.75,就用检索结果作为最终chunk,如果检索结果RR score低于0.75,则用模型生成额chunk
实验结果:小模型stableVicuna在两类训练集,5种检索组合上表现都更好
6 结论
1、使用两种数据源作为推理依据:LLM生成的文本c1和检索文本c2
2、提出两种方法,均能提升推理效果:1、在RATD数据集上训练 2、用Rationale Ranking Model来决定c1和c2的组合方式
3、基于这两种方法,在验证集上较小的13BLLM的效果高于175B的模型
4、LLM在多跳回答上表现较弱,但在常识性问题上较强。
5、关于c1和c2的组合方式,基于RR Model的组合方式表现更好。