SELF-RAG论文全流程阅读解析

Self-RAG研究学习记录

论文发布的太新太新了,找了小半天也没有相关资料和应用实践,只能啃啃论文,结合自己的理解,以作记录。
论文链接:https://arxiv.org/pdf/2310.11511.pdf
github: https://github.com/AkariAsai/self-rag.git

Self-RAG(Self-Reflective Retrieval-Augmented Generation):自反思的检索增强生成方法。 SELF-RAG 是一种训练任意语言模型(LM)的方法,使其在推理阶段具有可控性。通过让 LM 反思自己的生成过程。

SELF-RAG 使得模型在多种任务上表现优越,如开放领域问答、推理和事实验证。实验结果显示,Self-RAG 在这些任务上的表现优于现有的大型语言模型(如 ChatGPT)以及检索增强模型(如检索增强的 Llama2-chat)。尤其是在事实性和引用准确性方面,Self-RAG 具有显著的提升。

我们目前使用的chatchat架构,以及目前市面上常见的AI知识检索解决方案,都遵循着共通的范式:query+context→LLM。
query 表示用户的输入,context 表示从向量库中检索获得的信息,然后共同输入到 LLM 中,这是一种检索前置的被动的增强方式。

chatchat存在的问题
  • 检索效果依赖 embedding 和检索算法,目前可能检索到相似但是无关信息,反而对输出有负面影响;如:办公电话夜间开机作业怎么办 ,会和 机房电脑夜间开机怎么办 进行匹配 ;
  • 大模型如何利用检索到的信息仍是黑盒的,可能仍存在不准确(甚至生成的文本与检索信息相冲突);如:回答完问题后,会输出 “根据已知信息,无法回答该问题”
  • 对所有任务都无差别强行检索 top_k 个文本片段,当有不相关文本片段被匹配出来时,会影响准确性;
Self-RAG方案理解

相较chatchat的检索前置被动的增强的方案,Self-RAG 是更加主动和智能的实现方式,真正应用上了AI大模型的能力。

SELF-RAG往语言模型的词表中引入了4种新类型的反思令牌,分别是Retriver, IsRel, IsSup, IsUse,对应四种不同的子任务,如下图:

special tokens

以第一个任务Retriever为例, 该任务会判断当前问题下是否需要检索模块支持,对应反思令牌有yes,no跟continue,分别对应不同选择。通过让语言模型生成对应的反思令牌,使得语言模型具备判断是否需要检索以及评判生成结果的能力。

这里值得注意的是,也可以主动由检索模块先来判断是否需要支持,然后把标签提前打好。官方给出的实例代码展示了**主动检索并打Retriever标签 **和 被动由模型识别去打Retriever标签

x:代表问题; d:代表知识片段; y:代表模型根据知识片段生成的回答;

Retrieve任务:模型判断这段话或者这个问题结束后,需不需要去知识库里查询信息。

IsREL任务:检索出来的知识片段y,对我解决问题x有没有帮助。

IsSUP任务:我生成出来的回答y,能不能被知识片段d证明。

IsUSE任务:我生成的y,能不能解决x这个问题。

并且作者把很好的token,用加粗字体标识出来了,如:relevant、full supported、5。

Self-RAG主要运作如下:

  1. 判断是否需要额外检索事实性信息(Retrieve任务),仅当有需要时才召回相关知识片段;
  2. 并发组装每个召回的知识片段,产生prompt(就是用户的问题)+ 一个知识片段的prompt;
  3. 将组装好的各个prompt,再度并发投入到模型中,由模型进行判断用户问题和知识片段的相关性(IsREL任务),并且生成回答y。
  4. 模型生成y后,模型会判断y是否可以由知识片段支撑(IsSUP任务);
  5. 将IsREL任务和IsSUP任务的结果,依次加入到模型前边的输出y中,如果IsREL任务是Relevant,则继续进行评判y是否是有价值的回答(IsUse任务),原理是根据IsREL任务和IsSUP任务的结果标签对y进行打分5~1分,越大越趋近于问题的正确回答,如下图中的Relevant+Supported > Relevant+Partially > Irrelevant 。
  6. 将分数最高的知识片段和该知识片段的IsREL任务和IsSUP任务的反思字段结果,从新投入到大模型中,让模型进行再度理解和整合,输出正确答案。

整体原理如图右所示(图左是我们现在使用的chatchat的实现原理):

在这里插入图片描述

Self-RAG代码理解

从官方入门代码,结合上边的理论进行入手,官方代码如下:

from vllm import LLM, SamplingParams

model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
  prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
  if paragraph is not None:
    prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
  return prompt

query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

# for a query that doesn't require retrieval
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
  print("Model prediction: {0}".format(pred.outputs[0].text))

对核心代码进行逐行分析一下:

sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

这行创建了一个SamplingParams对象,它定义了用于文本生成的一些参数。其中:

  • temperature=0.0: 这是生成文本时的“温度”。温度越低,生成的文本越确定;温度越高,文本越随机。

  • top_p=1.0: 这是一个用于核采样的参数,可以影响生成的文本的多样性。

  • max_tokens=100: 生成的文本的最大令牌数。

  • skip_special_tokens=False: 是否跳过特殊令牌。

def format_prompt(input, paragraph=None):
  prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
  if paragraph is not None:
    prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
  return prompt

这里定义了个format_prompt函数,用于为每个查询进行格式化提示。如果有知识片段的情况下,会在模板后拼接 “[Retrieval]知识内容”,这里的知识片段据我理解应该就是匹配向量库出来的知识片段。

query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

# for a query that doesn't require retrieval
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
  print("Model prediction: {0}".format(pred.outputs[0].text))

输出:

1、Model prediction: Twitter, Instagram, and WhatsApp are all social media platforms. [No Retrieval]WhatsApp is the odd one out because it is a messaging app, while Twitter and # Instagram are primarily used for sharing photos and videos.[Utility:5]</s> 
    
2、Model prediction: Sure![Retrieval]<paragraph><paragraph>

定义了两个问题,并将问题先调用format_prompt进行知识模板拼装,但是通过这块代码可以看出,这两个问题进行拼装时都没有传知识片段,也就是说不会在这里进行拼接 “[Retrieval]知识内容”

但是根据官方给出的输出答案来看,答案是拼装了 “[Retrieval]”,也就是说,训练好的模型是具备自主判断这个问题需不需要召回知识片段(Retrieve任务),是不是可以直接回答,这代表需要对模型进行完成度很高的训练

而自己写的前置format_prompt函数,估计是为了让模型的的Retrieve任务更加可控,不完全依赖于模型的Retrieve。

当 Self-RAG 不需要检索时,它会在第一个查询中开始生成无需检索的响应。另一方面,Self-RAG 输出第二个标记,因为这个问题需要更细粒度的事实基础。[Retrieve]

对于需要事实基础的查询,可以插入段落。Self-RAG 可以在生成时随时检索和插入段落,只要它们被上下文标记特殊标记包围,就可以识别它们。<paragraph>``</paragraph>

# for a query that needs factual grounding
prompt = format_prompt("Can you tell me the difference between llamas and alpacas?", "The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.")

preds = model.generate([prompt], sampling_params)

print([pred.outputs[0].text for pred in preds])

# ['[Relevant]Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.[Fully supported][Utility:5]</s>']

The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.这段就代表从向量库中匹配出来的知识片段。

将知识片段与问题一起通过format_prompt方法进行组装,然后投入到模型中,获得回答y,以及IsREL任务、IsSUP任务、IsUse任务的结果。可以看到在多个相似答案中,Self-RAG 找到相关的知识文档并生成最正确的答案。

官方只给到这一步,但是从上边的流程图中可以看到,其实还有一步是将知识片段和IsSUP任务、IsUse任务结果,从新投入到模型再度进行生成,然后会组合出额外的完整的回答。这个额外的4部,应该是基于重新投入的知识片段进行处理后推理得出的新总结。

SELF-RAG算法推理

图4: SELF-RAG推理流程图

生成模型M首先会判断当前问题是否需要使用检索,如果需要的话,就会检索召回多个相关文档(图中Retrieve部分),通过并行的方式同时处理多个文档,并生成对应回复,再通过排序选择其中最合适的回复作为最终结果。

因为有反思令牌的存在,所以整个SELF-RAG模型的推理过程更加精确智能,真正应用上了大模型的思维能力。

从官方的这个算法推理图中也可以看出,SELF- RAG中存在了两个模型,其中评判模型C的用途是为了构建生成模型M所需要的训练数据,当生成模型训练完成后,在具体推理时只会用到生成模型M

SELF-RAG模型训练

论文原文链接:https://arxiv.org/pdf/2310.11511.pdf

训练部分的论文原文:

在这里插入图片描述

首先明确一点就是:SELF-RAG的训练过程涉及到两个模型,分别是评判模型Critic跟生成模型Generator(也就是前边各种图中所示的M模型),作者实验阶段的训练数据集是150K个输入输出指令对。

理解一下就是,SELF-RAG模型是基于任意自然语言大模型的,而SELF-RAG所需的4种反思token都不在自然语言大模型的语料库中。我们的模型既然要生成和识别这些token,就需要进行训练了。

作者为了节省成本,提出了一个的方法:写一个标注反思token的例子,然后把自己整理好的所有的自己领域相关的训练语料通过GPT4的API,让GPT4根据例子来进行插入反思token,然后把GPT4吐出来的训练语料+标注好的反思token作为完整的数据集,灌输给自己的M生成模型去学习,学习GPT4是怎么生成这些token的。看似很合理,但是有个问题。就训练语料是根据自己领域来做的,是很庞大的具有专业性质的数据库集合,让GPT4去生成时间和物力成本太大,受限也太大。

然后作者又提出了另一个方案,也是真正使用的方案,就是用一小部分自己专业领域的数据训练语料,去让GPT4生成标注反思token,然后让我们另一个语言模型C做微调。C微调的目的不是去学习文本如何生成,它只是去学习反思token如何生成的。

也就是说,给它一整句话,让他去学习最后生成的反思令牌。这样我们就有了一个专精于生成反思令牌的语言模型,这样再用我们专精于生成反思令牌的语言模型,去调整海量要喂给真正的生成式大语言模型M的数据。这样一方面,我们既获得了类似于GPT4一般的反思token生成能力,另一方面成本也降下来了。这个模型就是Critic,专门负责把训练语料+反思token生成后让大语言模型M去学习的工具模型。

在这里插入图片描述

论文中指出作者使用Open-Instruct processed data和knowledge-intensive datasets作为训练数据,在训练critic模型的时候是分别从这些训练数据里面随机采样一部分,每种大概是4k~20k条数据,共4种,使用提示工程让gpt4给这些训练数据做token的标注,然后再利用这一部分标注之后的数据训练critic模型专门学习生成反思令牌。

至于文本是如何生成的,Critic是完全不在乎的。Critic可以是任何大语言模型,作者为了图方便,应用的是和后边的生成模型M相同的Liama2-7B作为初始模型。

生成模型是如何做的,作者整理好全部的输入输出片段,对任何一个片段,都让Critic模型先判断一下每条语言片段会不会有token,都有哪些token,这样就可以让生成文字的M模型,既学习到了如何生成文本,也知道如何生成反思token,也扩充了M模型的语料库,把反思token融入进了语料库中了。

  • 24
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值