文章指出,传统RAG通过向量检索排序召回与Query相关的片段,通过prompt生成回复,LLMs与检索器之间存在语义鸿沟(LLMs难以有效利用检索器提供的信息)。下面来看看这篇文章引入检索信息增强RAG性能的trick。
方法
检索特征提取
在 R 2 A G R^{2} AG R2AG 中,首先从检索器 f R f_{R} fR 获取语义表示:
- 查询编码表示: q q q: x q = f R ( q ) x^{q}=f_{R}(q) xq=fR(q)
- 文档编码表示 d d d: x d = f R ( d ) x^{d}=f_{R}(d) xd=fR(d)。
这样存在一个问题,这些表示不能直接使用,因为单一的表示无法捕捉到用于LLM生成的交互特征。
因此,为了适应各种检索器,需要将不同空间中的表示转换为统一格式的特征。提出三种相似计算方法来对这些表示进行对齐,从而得到检索特征。
-
相关性得分
相关性得分 r i r_{i} ri 是查询和第 i i i 个文档之间的相关性,也用于对文档进行排序。
-
前例相似性得分
前例相似性得分计算的是第 i i i 个文档表示与其在排名列表中的前例加权表示之间的相似性。
-
邻居相似性得分
邻居相似性得分计算的是第 i i i 个文档表示与其相邻表示之间的平均相似性。
这些得分通过相似性函数(如点积或余弦相似性)计算得出。具体的公式如下:
r i = sim ( x q , x i d ) , r_i = \operatorname{sim}(x^q, x_i^d), ri=sim(xq,xid),
$$
\gamma_i = \operatorname{sim}\left(x_i^d, \sum_{j=1}^{i-1} w_j \cdot x_j^d\right),\quad w_j = \frac{\exp(r_j)}{\sum_{\ell=1}^k \exp(r_\ell)}
$$
ζ i = { sim ( x 1 d , x 2 d ) , i = 1 sim ( x i − 1 d , x i d ) + sim ( x i d , x i + 1 d ) 2 , i ∈ [ 2 , k ) , sim ( x k − 1 d , x k d ) , i = k , \zeta_i = \begin{cases} \operatorname{sim}(x_1^d, x_2^d), & i=1 \\ \frac{\operatorname{sim}(x_{i-1}^d, x_i^d) + \operatorname{sim}(x_i^d, x_{i+1}^d)}{2}, & i \in [2, k), \\ \operatorname{sim}(x_{k-1}^d, x_k^d), & i=k \end{cases}, ζi=⎩ ⎨ ⎧sim(x1d,x2d),2sim(xi−1d,xid)+sim(xid,xi+1d),sim(xk−1d,xkd),i=1i∈[2,k),i=k,
其中, γ i \gamma_{i} γi 表示第 i i i 个文档与其在排名列表中的前例之间的相似性, r i r_{i} ri 是查询与第 i i i 个文档之间的相关性。
最后,将这三个特征拼接起来作为输入:
input i = { r i , γ i , ζ i } , \text{input}_i = \{r_i, \gamma_i, \zeta_i\}, inputi={ri,γi,ζi},
然后将特征列表 { input i } i = 1 k \{\text{input}_i\}_{i=1}^{k} {inputi}i=1k 输入到 R 2 R^{2} R2-Former 中,以进一步挖掘检索信息。
R 2 R^{2} R2-Former
R 2 R^{2} R2-Former 是 R 2 A G R^{2} AG R2AG 框架中引入的一个可训练模块,目的是弥合检索器和LLM之间的语义鸿沟。 R 2 R^{2} R2-Former 被设计为接受列表特征作为输入,并输出检索信息。
输入列表 { input i } i = 1 k \{\text{input}_i\}_{i=1}^{k} {inputi}i=1k, R 2 R^{2} R2-Former 处理输入过程公式如下:
H = f att ( f → h 1 ( { input i } i = 1 k ) + p ) , H = f_{\text{att}} \left( f_{\rightarrow h_{1}} \left( \{\text{input}_i\}_{i=1}^{k}\right) + p \right), H=fatt(f→h1({inputi}i=1k)+p),
其中:
- $ f_{\text{att}} $ 是具有 $ h_{1} $ 隐藏维度的 Transformer 编码器,
- $ f_{\rightarrow h_{1}} $ 是一个线性映射层,
- $ p \in R^{k \times h_{1}} $ 表示可训练的位置嵌入。
这个模块比较好理解,这一步通过利用自注意力机制来增强对检索器提供的列表特征的理解。
检索感知提示
步骤:
-
我们使用一个投影层将检索信息线性变换到与 LLM 的 token 嵌入层相同的维度:
E R = f → h 2 ( H ) = { e i R } i = 1 k , E^R = f_{\rightarrow h_2}(H) = \{e_i^R\}_{i=1}^k, ER=f→h2(H)={eiR}i=1k, -
使用 LLM 的分词器对查询和文档进行分词,并将其转换为嵌入。
E d = f emb ( t d ) = { e j d } j = 1 n d , E^{d} = f_{\text{emb}}(t^{d}) = \{e_{j}^{d}\}_{j=1}^{n_{d}}, Ed=femb(td)={ejd}j=1nd,
其中 $ f_{\text{emb}} $ 是 LLM 的 token 嵌入层,$ E^{d} \in R^{n_{d} \times h_{2}} $ 是文档 d d d 的嵌入。 -
检索信息的嵌入:为了对每个文档进行细致的分析,相应的检索信息嵌入被添加到每个文档嵌入的前面。这些嵌入作为外部知识,起到锚点的作用,引导 LLM 关注有用的文档。最终的输入嵌入可以排列如下:
其中 $ e_{i}^{R} $ 表示第 i i i 个文档的检索信息嵌入。通过这种方式,相应文档的检索信息可以很好地混合在一起,减少了 LLM 处理所有文档的负担。
-
生成响应:
y ^ = f G ( E ) , \hat{y} = f_{G}(E), y^=fG(E),
其中 $ \hat{y} $ 表示 LLM 生成的最终结果。
这一模块主要是将检索信息作为额外的知识输入,增强了 LLM 对文档的理解能力。
训练策略
主要是训练 R 2 R^{2} R2-Former 和 LLM 的对齐训练。
-
训练 R 2 R^{2} R2-Former
R 2 R^{2} R2-Former 是一个查询-文档匹配任务,是一个二分类任务:
s ^ = f → 1 ( H ) = { s ^ i } i = 1 k , \hat{s} = f_{\rightarrow 1}(H) = \{ \hat{s}_i \}_{i=1}^k, s^=f→1(H)={s^i}i=1k,
其中 f → 1 f_{\rightarrow 1} f→1 是一个二分类头,输出文档的相关性预测 s ^ \hat{s} s^。支持 s = { s i } i = 1 k s = \{s_i\}_{i=1}^k s={si}i=1k 是文档的真实标签,交叉熵作为损失函数,定义为:
L Q D M ( s , s ^ ) = − ∑ i = 1 k s i log ( s ^ i ) + ( 1 − s i ) log ( 1 − s ^ i ) . \mathcal{L}_{QDM}(s, \hat{s}) = -\sum_{i=1}^{k} s_i \log(\hat{s}_i) + (1 - s_i) \log(1 - \hat{s}_i). LQDM(s,s^)=−i=1∑ksilog(s^i)+(1−si)log(1−s^i).
LLM 的对齐训练
语言建模损失 L L M \mathcal{L}_{LM} LLM
联合训练
联合训练使得 R 2 R^{2} R2-Former 能够更好地理解来自检索器的列表特征,确保检索信息可以被 LLM 深入解释。
总体损失:
L = L Q D M + L L M . \mathcal{L} = \mathcal{L}_{QDM} + \mathcal{L}_{LM}. L=LQDM+LLM.
文中, R 2 A G R^{2} AG R2AG 提供了仅训练 R 2 R^{2} R2-Former 而冻结 LLM ,或同时训练
实验
参考文献
- https://arxiv.org/pdf/2406.13249v2