项目分享|RAG-Retrieval库实现基于LLM偏好监督RAG检索器微调

如何提升RAG性能?对于黑盒大模型比如GPT4来说,比较合适的是冻住LLM,利用LLM对文档的偏好作为监督信号微调检索器

开源地址:https://github.com/NLPJCL/RAG-Retrieval

这里面关键一环是如何得到监督信号,即对于每个用户查询Q和一些文档 𝐷1,𝐷2,…,𝐷𝐾 ,得到Q和每个文档的分数。对于编码器解码器架构来说,可以利用Fusion-In-Decoder(FiD)方法得到每个document相当于query的分数 。对于现在比较流行的解码器架构来说,比较好的思路是利用LM的输出概率来得到监督信号 。

2024.6.4更新:这两天发现GPT3.5/4以及大多数基于API的大模型,并不支持获取输入token的概率,一个可替代的方案是将查询和文档拼接输入大模型得到相关度

得到监督信号后,利用KL散度检索器输出的Q-D分数LM偏好的Q-D分数对齐即可微调检索器。本文面向解码器架构,结合理论和实践介绍第二种思路。

计算LM偏好监督信号

首先详细介绍如何得到监督信号:

  1. 检索:对于某个问答任务,我们有N个训练问答对 (𝑄,𝐴) ,我们首先对于每个 𝑄 ,用目前的检索器在自己的文档库里检索K个文档(K一般小于等于20),得到文档集合 𝐷′ 。
  2. 计算logits:我们将其中每个文档 𝐷𝑖 都放入当前的上下文问答模板中,得到提示 𝑃𝑖 。之后我们将 𝑃𝑖 和 𝐴分别用tokenizer编码后的token序列拼接在一起,输入模型。设 𝑃𝑖 编码得到token序列长度为 𝑚𝑝 , 𝐴 编码得到token序列长度为 𝑚𝑎 ,那么最后我们取logits中 [𝑚𝑝−1:𝑚𝑝+𝑚𝑎−2] 位置的logits(编号从0开始)
  3. 计算概率:我们先将这些logits做softmax归一化,之后挨个取出 𝐴 中每个token的对应概率,并将这些概率取平均作为LM偏好的 𝑄 和 𝐷𝑖 的分数
  4. 文档分数归一化:最后我们对 𝐷′ 中所有LM反馈的文档分数进行softmax归一化,就得到了最终的Q-D相似度。如下图所示

img

KL散度微调检索器

有了监督信号,我们利用如下的KL散度损失,用检索器输出的Q-D相似度源分布来拟合LM偏好的相似度目标分布。

img

实践

我们基于 RAG-Retrieval ,实现基于LLM偏好来监督RAG检索器微调。先使用如下命令克隆和安装rag-retrieval包。

git clone https://github.com/NLPJCL/RAG-Retrieval
pip install rag-retrieval
cd examples/synthetic_data

本文的监督信号获取依托于FlashRAG,旨在给读者提供一个可参考的路径,读者可以使用自己需要的LLM和数据来完成自己的监督信号获取。关于此库的基本用法,读者可以先参考笔者的前一篇文章,下文将假设读者已经掌握这篇博客的知识,略去一些评估细节。

https://blog.csdn.net/weixin_45783724

1.索引建立

我们先按照FlashRAG的文档安装好flashrag,以及建立 e5-base-v2 模型的 wiki-18.jsonl 文档的faiss索引。

2.数据集构建

修改 flashrag_config.yaml 中以下内容:

  • index_path 改为上一步得到的索引文件的目录
  • llama2-7B-chat 的路径改为你需要的LLM,理论上任意hf的LLM都可以适用。这里我们使用 meta-llama/Meta-Llama-3-8B-Instruct
  • data_dir 改为 FlashRAG_datasets 下载到本地的路径
  • corpus_path 改为 wiki-18.jsonl 对应的路径

这里我们以NQ的测试集为例,得到基于LLaMA-3-Instruct的概率得到的数据集

使用如下命令完成数据集的建立:

 python3 get_lm_probs_dataset.py \  # 计算LM概率代码参考get_lm_probs_dataset.py
 --dataset_name nq \
 --split test \
 --num 4000 \ # 查询数量
 --gpu_id 0 \
 --output lmsft.jsonl \ # jsonl 输出路径
 --topk 20

得到的jsonl文件数据格式如下,query为用户查询,pos为若干文档,scores为每个文档相对于query的分数

 {"query":"xxx", "pos":["yyy" ,"zzz"], "scores": [0.2, 0.8]}
 ...

3.训练和评估

我们使用train/train_embeddings.py使用此数据进行训练,即可得到面向NQ任务的专属检索器,使用如下命令

cd ../../rag_retrieval/train/embedding
CUDA_VISIBLE_DEVICES="0" python3 train_embedding.py  \
--model_name_or_path "intfloat/e5-base-v2" \
--dataset "../../../examples/synthetic_data/build/methods/lmsft.jsonl" \
--output_dir "./output/lmsft_example" \
--batch_size 64 \
--lr 2e-5 \
--epochs 5 \
--save_on_epoch_end 1 \
--gradient_accumulation_steps 1  \
--log_with 'wandb' \
--warmup_proportion 0.1 \
--neg_nums 15 \
--temperature 0.01 \
--query_max_len 128 \
--passage_max_len 512

我们将FlashRAG中的检索器地址换成我们训练好的检索器,并重新建立索引在NQ上测试,会得到下表中的结果,微调后的模型在NQ测试的表现明显优于原方法,在几个方法Naive-RAG,REPLUG ,Iter-Retgen ,SURE 基础上都提升了5个点左右,说明整套流程确实可以一定程度提升RAG性能。

读者可以根据具体业务,选择相应的数据,用这套流程进行微调,从而提升自己的RAG算法性能

img


大家好,我是NLP研究者BrownSearch,如果你觉得本文对你有帮助的话,不妨点赞收藏支持我的创作,您的正反馈是我持续更新的动力!如果想了解更多LLM/检索的知识,记得关注我!

这里也向读者推荐本文中提到的笔者最近参与的两个项目,两个项目的作者都很nice,欢迎大家一起来贡献代码!

第一个是能够以统一方式推理和微调各种检索模型的RAG-Retrieval:

https://github.com/NLPJCL/RAG-Retrievalgithub.com/NLPJCL/RAG-Retrieval

第二个是能够帮助科研人员快速且公平的搭建和评估RAG算法的FlashRAG:

https://github.com/RUC-NLPIR/FlashRAGgithub.com/RUC-NLPIR/FlashRAG

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值