【论文阅读】W-RAG: Weakly Supervised Dense Retrieval in RAG for Open-domain Question Answering

论文地址:https://arxiv.org/pdf/2408.08444

论文代码:https://github.com/jmnian/weak_label_f or_rag

背景

这篇文章主要是针对Dense Retrieval微调需要耗费大量人力进行标注的难题,提出了一种利用llm的rerank能力来创建弱标记数据,用于训练检索的向量模型

3个贡献点:

1、提出了W-RAG,一个数据构造方法来训练dense retriever

2、执行了多个实验,证明W-RAG带来了检索和问答上的提升

3、使用了开源模型,并且公开了W-RAG代码

具体方法

W-RAG过程分为三个阶段:

Step1:通过BM25从证据语料库中检索相关段落

Step2:  其次,它使用LLM通过根据生成gournd_truth_answer的probability对检索到的段落重新排序来从中生成弱标签; 因此我理解这里的弱标签其实是指排序顺序。(这里有一个大大的疑惑,为什么不直接用LLM去rerank一下这些passage,而要转而用一种更复杂的方式

Step3: 通过得到的weak label对dense retriever进行训练

Weak-label 生成

生成label的第一步是从corpus中召回一批passage,通过BM25或者dense retriever,然后将query,passage和instruction和answer组装成prompt。 prompt长这样:

(这里也有个疑惑就是,为什么要把answer放到prompt里)

然后通过大模型对该prompt生成answer。

从每个passage中生成ground_truth answer的概率通过以下公式衡量:

对于大模型来说,answer是由多个token组成的,所以生成第a_j个token需要由(passage、

),query, instuction和a_j之前的token决定。每个token的生成的概率直接从模型的logits层就能获得。

另外,由于概率累乘会造成数值下溢(underflow)的问题,因此通过两边log转化为log- likelihood,将概率的乘积转化为对数和,大大降低了下溢的风险。

log-likelihood的值也作为每一个candidate passage的相关度分数,即Figure 1中提到的answer relevance分数。

这里可以去看下源码是如何去做这个计算的

训练dense retriever

上一步中已经获取到了query,passage, query-passage-relevence score, 确定positive sample和negative sampe以后,就可以开始训练模型噜;选择两个dense retriever, DPR和colBERT。

DPR:

DPR utilizes a bi-encoder architecture, where the question and passage are independently mapped to an embedding space through two separate BERT encoders.

DRP用的训练方式是“in-batch negative training“,使用top 1的passage作为positive sample, top 2-n的作为negative sample

loss funtion为:

The scaler 𝛼 is used to amplify the cosine similarity score, usually set at 20 according to the default setting in sentence-transformers 2.

其中R_{q,s_i}指q和s的向量相似度,MNR的分子是positive sample,而分母是all

ColBERT

ColBERT employes a bi-encoder architecture with late interaction, where the question and passage are independently encoded using a shared BERT model to obtain embeddings 𝐸𝑞 and 𝐸𝑠 .

colBERT的训练数据构造,<q, s^+, s^->,其中s^+表示positive sample,s^-表示hard negative sample,loss函数定义如下:

R_{q,s} 同样也是在做相似度计算,但不同与余弦相似度,这里采用了"MaxSim"这种形式,是query和passage之间token级别的相似度计算

实验

知名开源数据集:

MSMARCO、QnA v2.1, NQ, SQuAD, WebQ

每个数据集选取5000 qa pairs和500,000 passages

模型

Generator: Llama3-8B-Instruct

Retriever: DPR (bert-base-uncased、Yibin-Lei/ReContriever), ColBERT(bert-base-uncased)

Baseline

RAG效果对比

Naive:直接用llama3作答

GroundTruth:将ground-truth文档插入到prompt中作答

检索效果对比:

无监督模型:BM25、ColBERT,Contriever,ReContrever

基于GroundTruth训练的模型:DPR、ColBERT

metric

检索:recall

生成:F1、Rouge-L、BLEU-1

实验结果

相比于naive,提升较大,接近于goundtruthbm

BM25相较于llama3-8B在排序能力上的差距

消融实验

对比不同的模型排序能力

对比不同prompt的表现,zero、one,two之间的差别不大

对比不同的文档数量对效果的提升

未来工作

1、探究不同类型的文档,对LLM生成的影响

2、retrievaled results compression

3、RAG robustness

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值