如何提升RAG性能?对于黑盒大模型比如GPT4来说,比较合适的是冻住LLM,利用LLM对文档的偏好作为监督信号来微调检索器。
开源地址:https://github.com/NLPJCL/RAG-Retrieval
这里面关键一环是如何得到监督信号,即对于每个用户查询Q和一些文档 𝐷1,𝐷2,…,𝐷𝐾 ,得到Q和每个文档的分数。对于编码器解码器架构来说,可以利用Fusion-In-Decoder(FiD)方法得到每个document相当于query的分数 。对于现在比较流行的解码器架构来说,比较好的思路是利用LM的输出概率来得到监督信号 。
2024.6.4更新:这两天发现GPT3.5/4以及大多数基于API的大模型,并不支持获取输入token的概率,一个可替代的方案是将查询和文档拼接输入大模型得到相关度。
得到监督信号后,利用KL散度将检索器输出的Q-D分数和LM偏好的Q-D分数对齐即可微调检索器。本文面向解码器架构,结合理论和实践介绍第二种思路。
计算LM偏好监督信号
首先详细介绍如何得到监督信号:
- 检索:对于某个问答任务,我们有N个训练问答对 (𝑄,𝐴) ,我们首先对于每个 𝑄 ,用目前的检索器在自己的文档库里检索K个文档(K一般小于等于20),得到文档集合 𝐷′ 。
- 计算logits:我们将其中每个文档 𝐷𝑖 都放入当前的上下文问答模板中,得到提示 𝑃𝑖 。之后我们将 𝑃𝑖 和 𝐴分别用tokenizer编码后的token序列拼接在一起,输入模型。设 𝑃𝑖 编码得到token序列长度为 𝑚𝑝 , 𝐴 编码得到token序列长度为 𝑚𝑎 ,那么最后我们取logits中 [𝑚𝑝−1:𝑚𝑝+𝑚𝑎−2] 位置的logits(编号从0开始)
- 计算概率:我们先将这些logits做softmax归一化,之后挨个取出 𝐴 中每个token的对应概率,并将这些概率取平均作为LM偏好的 𝑄 和 𝐷𝑖 的分数
- 文档分数归一化:最后我们对 𝐷′ 中所有LM反馈的文档分数进行softmax归一化,就得到了最终的Q-D相似度。如下图所示
KL散度微调检索器
有了监督信号,我们利用如下的KL散度损失,用检索器输出的Q-D相似度源分布来拟合LM偏好的相似度目标分布。
实践
我们基于 RAG-Retrieval ,实现基于LLM偏好来监督RAG检索器微调。先使用如下命令克隆和安装rag-retrieval包。
git clone https://github.com/NLPJCL/RAG-Retrieval
pip install rag-retrieval
cd examples/synthetic_data
本文的监督信号获取依托于FlashRAG,旨在给读者提供一个可参考的路径,读者可以使用自己需要的LLM和数据来完成自己的监督信号获取。关于此库的基本用法,读者可以先参考笔者的前一篇文章,下文将假设读者已经掌握这篇博客的知识,略去一些评估细节。
https://blog.csdn.net/weixin_45783724
1.索引建立
我们先按照FlashRAG的文档安装好flashrag,以及建立 e5-base-v2 模型的 wiki-18.jsonl 文档的faiss索引。
2.数据集构建
修改 flashrag_config.yaml 中以下内容:
- index_path 改为上一步得到的索引文件的目录
- llama2-7B-chat 的路径改为你需要的LLM,理论上任意hf的LLM都可以适用。这里我们使用 meta-llama/Meta-Llama-3-8B-Instruct
- data_dir 改为 FlashRAG_datasets 下载到本地的路径
- corpus_path 改为 wiki-18.jsonl 对应的路径
这里我们以NQ的测试集为例,得到基于LLaMA-3-Instruct的概率得到的数据集
使用如下命令完成数据集的建立:
python3 get_lm_probs_dataset.py \ # 计算LM概率代码参考get_lm_probs_dataset.py
--dataset_name nq \
--split test \
--num 4000 \ # 查询数量
--gpu_id 0 \
--output lmsft.jsonl \ # jsonl 输出路径
--topk 20
得到的jsonl文件数据格式如下,query为用户查询,pos为若干文档,scores为每个文档相对于query的分数
{"query":"xxx", "pos":["yyy" ,"zzz"], "scores": [0.2, 0.8]}
...
3.训练和评估
我们使用train/train_embeddings.py使用此数据进行训练,即可得到面向NQ任务的专属检索器,使用如下命令
cd ../../rag_retrieval/train/embedding
CUDA_VISIBLE_DEVICES="0" python3 train_embedding.py \
--model_name_or_path "intfloat/e5-base-v2" \
--dataset "../../../examples/synthetic_data/build/methods/lmsft.jsonl" \
--output_dir "./output/lmsft_example" \
--batch_size 64 \
--lr 2e-5 \
--epochs 5 \
--save_on_epoch_end 1 \
--gradient_accumulation_steps 1 \
--log_with 'wandb' \
--warmup_proportion 0.1 \
--neg_nums 15 \
--temperature 0.01 \
--query_max_len 128 \
--passage_max_len 512
我们将FlashRAG中的检索器地址换成我们训练好的检索器,并重新建立索引在NQ上测试,会得到下表中的结果,微调后的模型在NQ测试的表现明显优于原方法,在几个方法Naive-RAG,REPLUG ,Iter-Retgen ,SURE 基础上都提升了5个点左右,说明整套流程确实可以一定程度提升RAG性能。
读者可以根据具体业务,选择相应的数据,用这套流程进行微调,从而提升自己的RAG算法性能
大家好,我是NLP研究者BrownSearch,如果你觉得本文对你有帮助的话,不妨点赞或收藏支持我的创作,您的正反馈是我持续更新的动力!如果想了解更多LLM/检索的知识,记得关注我!
这里也向读者推荐本文中提到的笔者最近参与的两个项目,两个项目的作者都很nice,欢迎大家一起来贡献代码!
第一个是能够以统一方式推理和微调各种检索模型的RAG-Retrieval:
https://github.com/NLPJCL/RAG-Retrievalgithub.com/NLPJCL/RAG-Retrieval
第二个是能够帮助科研人员快速且公平的搭建和评估RAG算法的FlashRAG:
https://github.com/RUC-NLPIR/FlashRAGgithub.com/RUC-NLPIR/FlashRAG