MemoRAG
简介
MemoRAG出自2024年9月的论文《Memorag: Moving towards next-gen rag via memory-inspired knowledge discovery》 (github),它提出了长期记忆的概念,让一个轻量、长上下文窗口的LLM对数据集进行记忆后,针对用户的查询任务生成线索和回答草稿,再用这些信息去检索出相关数据库中与问题相关的信息,最后让一个更强大的LLM根据指令和检索出的信息生成最终的回答。
实现思路
标准的RAG框架可以用如下公式表示,式中的
Θ
(
⋅
)
\Theta(\cdot)
Θ(⋅)表示生成模型,
Γ
(
⋅
)
\Gamma(\cdot)
Γ(⋅)表示检索模型,q表示查询,
C
\mathcal{C}
C是从数据集
D
\mathcal{D}
D中检索的上下文,
Y
\mathcal{Y}
Y是最终的答案。
Y
=
Θ
(
q
,
C
∣
θ
)
,
C
=
Γ
(
q
,
D
∣
γ
)
(
1
)
\mathcal{Y} = \Theta(q, \mathcal{C}\ |\ \theta), \quad \mathcal{C} = \Gamma(q, \mathcal{D}\ |\ \gamma) \qquad (1)
Y=Θ(q,C ∣ θ),C=Γ(q,D ∣ γ)(1)
MemoRAG则包含一个记忆模型
Θ
m
e
m
(
⋅
)
\Theta_{mem}(\cdot)
Θmem(⋅)来作为查询q和数据集
D
\mathcal{D}
D之间的语义桥梁,如下面公式所表示,式中的y是由记忆模型生成的阶段性答案,这个答案可能不完整或者缺乏细节,但是它可以作为指导相关上下文的线索。
Y
=
Θ
(
q
,
C
∣
θ
)
,
C
=
Γ
(
y
,
D
∣
γ
)
,
y
=
Θ
m
e
m
(
q
,
D
∣
θ
m
e
m
)
(
2
)
\mathcal{Y} = \Theta(q, \mathcal{C}\ |\ \theta), \quad \mathcal{C} = \Gamma(y, \mathcal{D}\ |\ \gamma), \quad y = \Theta_{mem}(q, \mathcal{D}\ |\ \theta_{mem}) \qquad (2)
Y=Θ(q,C ∣ θ),C=Γ(y,D ∣ γ),y=Θmem(q,D ∣ θmem)(2)
阶段性答案y的形式与具体任务相关(根据任务来设计不同的prompt),比如,对于QA任务,如果query不明确,阶段性答案y可能包括一些中间步骤,比如生成更明确和消歧的替代query,以及数据集中对最终问答有用的文本证据;对于摘要任务,阶段性答案y可能包括一些从上下文中得到的要点和概念。
记忆模块
记忆模块将输入的token变成更小的记忆token,同时保留文本里的关键语义信息。
对于普通的基于transformer网络的模型
Θ
(
⋅
)
\Theta(\cdot)
Θ(⋅),假设输入
X
\mathcal{X}
X包括n个token:
X
=
{
x
1
,
⋯
,
x
n
}
\mathcal{X} = \{x_1, \cdots, x_n \}
X={x1,⋯,xn},每个注意力层可以表示成如下:
Q
=
X
W
Q
,
K
=
X
W
K
,
V
=
X
W
V
(
3
)
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
,
Θ
(
X
)
=
Attention
(
Q
,
K
,
V
)
(
4
)
\mathcal{Q} = \mathcal{X} \mathbf{W}_{\mathcal{Q}} , \quad \mathcal{K}= \mathcal{X} \mathbf{W}_{\mathcal{K}}, \quad \mathcal{V}=\mathcal{X}\mathbf{W}_{\mathcal{V}} \qquad (3)\\ \text{Attention}(\mathcal{Q}, \mathcal{K}, \mathcal{V}) = \text{softmax}\left(\frac{\mathcal{Q} \mathcal{K}^T}{\sqrt{d_k}}\right) \mathcal{V},\quad \Theta(\mathcal{X}) = \text{Attention}(\mathcal{Q}, \mathcal{K}, \mathcal{V}) \qquad (4)
Q=XWQ,K=XWK,V=XWV(3)Attention(Q,K,V)=softmax(dkQKT)V,Θ(X)=Attention(Q,K,V)(4)
上式中的
W
Q
\mathbf{W}_{\mathcal{Q}}
WQ、
W
K
\mathbf{W}_{\mathcal{K}}
WK、
W
V
\mathbf{W}_{\mathcal{V}}
WV是attention层的query、key、value的投影权重矩阵,
d
k
d_k
dk是key向量的维度,论文认为输入经过多个transformer层后得到对输入
X
\mathcal{X}
X更充分的理解,这有点类似于人类的短期记忆主要包括当前收到的信息。将这个过程记作
X
=
Θ
(
X
)
\boldsymbol{\mathcal{X}} = \Theta(\mathcal{X})
X=Θ(X) ,
X
\boldsymbol{\mathcal{X}}
X是输出序列
X
\mathcal{X}
X的隐状态,
Θ
(
⋅
)
\Theta(\cdot)
Θ(⋅)可以是任何LLM。
为了将短期记忆变成长期记忆,Memorag作者引入了记忆token
x
m
x^m
xm作为LLM长期记忆的信息载体。设LLM
Θ
(
⋅
)
\Theta(\cdot)
Θ(⋅)的工作上下文窗口长度为l,在每个上下文窗口后添加k个记忆token,记为(其实有点疑惑为什么在k个记忆字符后还表示了 一些token):
X
=
{
x
1
,
⋯
,
x
l
,
x
1
m
,
⋯
,
x
k
m
,
x
l
+
1
,
⋯
}
,
k
≪
l
(
5
)
\mathcal{X} = \{x_1, \cdots, x_l, x^m_1, \cdots, x^m_k, x_{l+1}, \cdots \}, k\ll l \qquad (5)
X={x1,⋯,xl,x1m,⋯,xkm,xl+1,⋯},k≪l(5)
将attention层的公式(4)修改为下式:
Q
m
=
X
W
Q
m
,
K
m
=
X
W
K
m
,
V
m
=
X
W
V
m
(
6
)
Attention
(
Q
,
K
,
V
)
=
softmax
(
[
Q
;
Q
m
]
[
K
;
K
m
;
K
cache
m
]
T
d
k
)
[
V
;
V
m
;
V
cache
m
]
(
7
)
\mathcal{Q}^m = \mathcal{X} \mathbf{W}_{\mathcal{Q}^m} , \quad \mathcal{K}^m= \mathcal{X} \mathbf{W}_{\mathcal{K}^m}, \quad \mathcal{V}^m=\mathcal{X}\mathbf{W}_{\mathcal{V}^m} \qquad (6)\\ \text{Attention}(\mathcal{Q}, \mathcal{K}, \mathcal{V}) = \text{softmax}\left(\frac{[\mathcal{Q};\mathcal{Q}^m] [\mathcal{K};\mathcal{K}^m;\mathcal{K}^m_{\text{cache}}]^T}{\sqrt{d_k}}\right) [\mathcal{V};\mathcal{V}^m; \mathcal{V}^m_{\text{cache}}] \qquad (7)
Qm=XWQm,Km=XWKm,Vm=XWVm(6)Attention(Q,K,V)=softmax(dk[Q;Qm][K;Km;Kcachem]T)[V;Vm;Vcachem](7)
上式中
W
Q
m
\mathbf{W}_{\mathcal{Q}^m}
WQm、
W
K
m
\mathbf{W}_{\mathcal{K}^m}
WKm、
W
V
m
\mathbf{W}_{\mathcal{V}^m}
WVm是针对记忆信息而初始化的权重矩阵,它们将记忆token
x
m
x^m
xm变成query、key、query向量:
Q
m
\mathcal{Q}^m
Qm、
K
m
\mathcal{K}^m
Km、
V
m
\mathcal{V}^m
Vm,
K
cache
m
\mathcal{K}^m_{\text{cache}}
Kcachem和
V
cache
m
\mathcal{V}^m_{\text{cache}}
Vcachem是之前的memory token的缓存。将记忆token记为
X
m
\boldsymbol{\mathcal{X}}^m
Xm,上述转换过程记为
X
m
=
Θ
mem
(
X
)
\boldsymbol{\mathcal{X}}^m = \Theta_{\text{mem}}(\boldsymbol{\mathcal{X}})
Xm=Θmem(X)。 对于l个原始token
{
x
1
,
⋯
,
x
l
}
\{x_1, \cdots, x_l\}
{x1,⋯,xl},经过个多次由公式(7)定义的注意力机制过程后,将得到隐状态
X
[
0
:
l
]
=
{
x
1
,
⋯
,
x
l
,
x
1
m
,
⋯
,
x
k
m
}
\boldsymbol{X}_{[0:l]} = \{\boldsymbol{x}_1, \cdots, \boldsymbol{x}_l, \boldsymbol{x}_1^m, \cdots, \boldsymbol{x}_k^m\}
X[0:l]={x1,⋯,xl,x1m,⋯,xkm},其中的
{
x
1
,
⋯
,
x
l
}
\{\boldsymbol{x}_1, \cdots, \boldsymbol{x}_l\}
{x1,⋯,xl} 是原始token的隐状态,
{
x
1
m
,
⋯
,
x
k
m
}
\{\boldsymbol{x}_1^m, \cdots, \boldsymbol{x}_k^m\}
{x1m,⋯,xkm}是记忆token的隐状态,在记忆形成后,类似于人类记忆的遗忘过程,l个原始token的KV缓存被丢弃。经过n个上下文窗口后,MemoRAG将
X
\mathcal{X}
X中所有的原始token转化成了记忆token,将其记为
Θ
(
X
)
→
Θ
mem
(
X
)
=
{
x
1
,
1
m
,
⋯
,
x
1
,
k
m
,
⋯
,
x
n
,
k
m
}
\Theta(\mathcal{X}) \rightarrow \Theta_{\operatorname{mem}}(\mathcal{X})=\left\{\boldsymbol{x}_{1,1}^m, \cdots, \boldsymbol{x}_{1, k}^m, \cdots, \boldsymbol{x}_{n, k}^m\right\}
Θ(X)→Θmem(X)={x1,1m,⋯,x1,km,⋯,xn,km},意味着从输入
X
\mathcal{X}
X得到了全局记忆
X
m
\boldsymbol{\mathcal{X}}^m
Xm。
记忆模块的训练:因为记忆模块引入了权重矩阵 W Q m \mathbf{W}_{\mathcal{Q}^m} WQm、 W K m \mathbf{W}_{\mathcal{K}^m} WKm、 W V m \mathbf{W}_{\mathcal{V}^m} WVm,在训练时只更新这些权重,LLM的其他参数保持冻结状态。训练时分为两个阶段:(1)预训练:从RedPajama数据集中随机采样长文本来预训练模型,让记忆模块学习如何从原始文本形成记忆;(2)Supervised Finetuning(SFT):用任务相关SFT数据让MemoRAG基于记忆生成任务相关的线索。(SFT训练数据集是让能力比较强的LLM针对收集的长文本生成问答对以及线索,再人工筛选得到的)
记忆模型的目标函数如下,目的是在给定之前的记忆token
{
x
1
,
1
m
,
⋯
,
x
i
−
1
,
k
i
−
1
m
}
\{x^m_{1,1}, \cdots, x^m_{i-1,k_{i-1}} \}
{x1,1m,⋯,xi−1,ki−1m}和新的原始token
{
x
i
,
1
,
⋯
,
x
i
,
j
−
1
}
\{x_{i,1}, \cdots, x_{i,j-1}\}
{xi,1,⋯,xi,j−1}时最大化下一个token的生成概率:
max
Θ
mem
P
(
x
i
,
j
∣
x
1
,
1
m
,
⋯
,
x
i
−
1
,
k
i
−
1
m
,
x
i
,
1
,
⋯
,
x
i
,
j
−
1
)
\max_{\Theta_{\text{mem}}} \mathcal{P}(x_{i,j} \ | \ x^m_{1,1}, \cdots, x^m_{i-1, k_{i-1}}, x_{i,1}, \cdots, x_{i, j-1} )
ΘmemmaxP(xi,j ∣ x1,1m,⋯,xi−1,ki−1m,xi,1,⋯,xi,j−1)
Memorag的流程
-
先用训练好的记忆模型 Θ mem \Theta_{\text{mem}} Θmem对数据集进行记忆生成全局记忆 X m \boldsymbol{\mathcal{X}}^m Xm。(Memorag基于Qwen2-7B-Instruct 和 Mistral-7B-Instruct-v0.2训练并开源了记忆模型memorag-qwen2-7b-inst和memorag-mistral-7b-inst)
-
对于新的任务,让记忆模型使用全局记忆 X m \boldsymbol{\mathcal{X}}^m Xm生成任务相关的线索y。
-
使用检索器根据线索y从数据集中检索出跟任务相关的文本。
-
用生成模型根据检索出来的证据文本生成最终的答案 Y \mathcal{Y} Y: Y = Θ gen ( X ^ , q ∣ θ ) \mathcal{Y} = \Theta_{\text{gen}}(\hat{\mathcal{X}}, q|\theta) Y=Θgen(X^,q∣θ)。
下面的图片是文中列举的几个MemoRAG的应用场景案例。
总结
- 跟HyDE的思路有类似之处,只是相比于HyDE多了模型微调和对查询数据集的记忆。在论文的试验部分,HyDE也是比较的baseline,Memorag相比于HyDE性能有提升,但这个实验也能说明在不能使用Memorag时,可考虑优先使用HyDE,因为它简单且效果不错。
- Memorag要先去记忆数据集,所以感觉适用主要为数据不太多的场景,不然数据集太大,超过模型的上下文窗口就无法处理了。不过论文提到应用token压缩技术(token compression technique),可以使记忆模型的处理的上下文窗口长度扩大2-16倍。