我们用详细的数学推导和结论证明了该不等式的正确性,我们将其分为两个部分进行讨论:不等式的左部分和右部分。
证明不等式:
ln L K ≤ M ( q i , K ) ≤ max j { q i k j T d } − 1 L K ∑ j = 1 L K { q i k j T d } + ln L K \ln LK \leq M(q_i, K) \leq \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} - \frac{1}{LK} \sum_{j=1}^{LK} \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} + \ln LK lnLK≤M(qi,K)≤maxj{dqikjT}−LK1∑j=1LK{dqikjT}+lnLK
其中, M ( q i , K ) M(q_i, K) M(qi,K)定义为:
M ( q i , K ) = ln ( ∑ j = 1 L K exp ( q i k j T d ) ) − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) = \ln \left( \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \right) - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)=ln(∑j=1LKexp(dqikjT))−LK1∑j=1LK(dqikjT)
左部分的解释
首先,我们看不等式的左部分。对于每一个查询向量
q
i
q_i
qi,第一个项
M
(
q
i
,
K
)
M(q_i, K)
M(qi,K)是一个对数和指数函数,计算固定查询
q
i
q_i
qi和所有键的内积。我们定义:
f
i
(
K
)
=
ln
∑
j
=
1
L
K
exp
(
q
i
k
j
T
d
)
f_i(K) = \ln \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)
fi(K)=ln∑j=1LKexp(dqikjT)
根据Log-sum-exp网络的相关理论(Calafiore, Gaubert和Possieri 2018)以及进一步的分析,函数 f i ( K ) f_i(K) fi(K)是一个凸函数。此外, f i ( K ) f_i(K) fi(K)加上线性组合 k j k_j kj使得 M ( q i , K ) M(q_i, K) M(qi,K)成为固定查询情况下的凸函数。
然后,我们对单个向量
k
j
k_j
kj求导数:
∂
M
(
q
i
,
K
)
∂
k
j
=
exp
(
q
i
k
j
T
d
)
∑
j
=
1
L
K
exp
(
q
i
k
j
T
d
)
⋅
q
i
d
−
1
L
K
⋅
q
i
d
\frac{\partial M(q_i, K)}{\partial k_j} = \frac{\exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)} \cdot \frac{q_i}{\sqrt{d}} - \frac{1}{LK} \cdot \frac{q_i}{\sqrt{d}}
∂kj∂M(qi,K)=∑j=1LKexp(dqikjT)exp(dqikjT)⋅dqi−LK1⋅dqi
我们再将其整理为:
∂
M
(
q
i
,
K
)
∂
k
j
=
(
exp
(
q
i
k
j
T
d
)
∑
j
=
1
L
K
exp
(
q
i
k
j
T
d
)
−
1
L
K
)
⋅
q
i
d
\frac{\partial M(q_i, K)}{\partial k_j} = \left( \frac{\exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)} - \frac{1}{LK} \right) \cdot \frac{q_i}{\sqrt{d}}
∂kj∂M(qi,K)=
∑j=1LKexp(dqikjT)exp(dqikjT)−LK1
⋅dqi
最小值条件
为了找到该函数的最小值,我们需要让所有的导数为零,也就是令梯度为零:即:
∂
M
(
q
i
,
K
)
∂
k
j
=
0
\frac{\partial M(q_i, K)}{\partial k_j} = 0
∂kj∂M(qi,K)=0
这要求:
exp
(
q
i
k
j
T
d
)
∑
j
=
1
L
K
exp
(
q
i
k
j
T
d
)
=
1
L
K
\frac{\exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)} = \frac{1}{LK}
∑j=1LKexp(dqikjT)exp(dqikjT)=LK1
这意味着:
exp
(
q
i
k
j
T
d
)
=
∑
j
=
1
L
K
exp
(
q
i
k
j
T
d
)
L
K
\exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) = \frac{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{LK}
exp(dqikjT)=LK∑j=1LKexp(dqikjT)
对数化后,我们有:
q
i
k
j
T
d
=
ln
(
∑
j
=
1
L
K
exp
(
q
i
k
j
T
d
)
L
K
)
\frac{q_i k_j^T}{\sqrt{d}} = \ln \left( \frac{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{LK} \right)
dqikjT=ln
LK∑j=1LKexp(dqikjT)
进一步化简:
q
i
k
j
T
+
ln
L
K
=
ln
∑
j
=
1
L
K
exp
(
q
i
k
j
T
)
q_i k_j^T + \ln LK = \ln \sum_{j=1}^{LK} \exp \left( {q_i k_j^T}{} \right)
qikjT+lnLK=ln∑j=1LKexp(qikjT)
由于所有的
k
j
k_j
kj对应的值是相同的,我们得到:
k
1
=
k
2
=
⋯
=
k
L
K
k_1 = k_2 = \cdots = k_{LK}
k1=k2=⋯=kLK
此时,最小值为:
M
(
q
i
,
K
)
=
ln
L
K
M(q_i, K) = \ln LK
M(qi,K)=lnLK
自然地,这需要
k
1
=
k
2
=
⋯
=
k
L
K
k_1 = k_2 = \cdots = k_{LK}
k1=k2=⋯=kLK,我们有测量的最小值为
ln
L
K
\ln LK
lnLK,即:
M
(
q
i
,
K
)
≥
ln
L
K
M(q_i, K) \geq \ln LK
M(qi,K)≥lnLK
证明不等式的右半部分
我们要证明不等式的右半部分,使用如下定义的 M ( q i , K ) M(q_i, K) M(qi,K):
M ( q i , K ) = ln ( ∑ j = 1 L K exp ( q i k j T d ) ) − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) = \ln \left( \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \right) - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)=ln(∑j=1LKexp(dqikjT))−LK1∑j=1LK(dqikjT)
右半部分的证明步骤
根据图中的方法,我们逐步进行证明:
-
定义最大内积:
α = max j { q i k j T d } \alpha = \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} α=maxj{dqikjT} -
上界估计:
∑ j = 1 L K exp ( q i k j T d ) ≤ ∑ j = 1 L K exp ( α ) = L K ⋅ exp ( α ) \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \leq \sum_{j=1}^{LK} \exp(\alpha) = LK \cdot \exp(\alpha) ∑j=1LKexp(dqikjT)≤∑j=1LKexp(α)=LK⋅exp(α) -
对数运算:
ln ( ∑ j = 1 L K exp ( q i k j T d ) ) ≤ ln ( L K ⋅ exp ( α ) ) = ln L K + α \ln \left( \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \right) \leq \ln \left( LK \cdot \exp(\alpha) \right) = \ln LK + \alpha ln(∑j=1LKexp(dqikjT))≤ln(LK⋅exp(α))=lnLK+α -
结合 M ( q i , K ) M(q_i, K) M(qi,K)的定义:
我们有:
M ( q i , K ) = ln ( ∑ j = 1 L K exp ( q i k j T d ) ) − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) = \ln \left( \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \right) - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)=ln(∑j=1LKexp(dqikjT))−LK1∑j=1LK(dqikjT)使用第3步的结果:
M ( q i , K ) ≤ ln L K + α − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) \leq \ln LK + \alpha - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)≤lnLK+α−LK1∑j=1LK(dqikjT)由于 α \alpha α是定义的最大值:
α = max j { q i k j T d } \alpha = \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} α=maxj{dqikjT} -
最终表达式:
将 α \alpha α替换回表达式中,得到:
M ( q i , K ) ≤ ln L K + max j { q i k j T d } − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) \leq \ln LK + \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)≤lnLK+maxj{dqikjT}−LK1∑j=1LK(dqikjT)
最终结论
综上所述,我们已经证明了不等式的右半部分:
M
(
q
i
,
K
)
≤
max
j
{
q
i
k
j
T
d
}
−
1
L
K
∑
j
=
1
L
K
(
q
i
k
j
T
d
)
+
ln
L
K
M(q_i, K) \leq \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) + \ln LK
M(qi,K)≤maxj{dqikjT}−LK1∑j=1LK(dqikjT)+lnLK
因此,引理1的右半部分不等式得证。
举例说明
假设有一个查询向量 q i q_i qi和一个键集合 K = { k 1 , k 2 , … , k L K } K = \{k_1, k_2, \ldots, k_{LK}\} K={k1,k2,…,kLK}。假设 d = 1 d = 1 d=1并且每个键向量 k j k_j kj都是相同的向量 k k k。此时, q i q_i qi和每个键的内积相同,即 q i k T q_i k^T qikT。根据以上推导:
- 对于左部分,因为所有 k j k_j kj都相同,所以 exp ( q i k j T d ) \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) exp(dqikjT)的和为 L K ⋅ exp ( q i k T d ) LK \cdot \exp \left( \frac{q_i k^T}{\sqrt{d}} \right) LK⋅exp(dqikT),对数之后得到 ln L K + q i k T d \ln LK + \frac{q_i k^T}{\sqrt{d}} lnLK+dqikT。
- 对于右部分,选择最大的内积(其实就是 q i k T q_i k^T qikT),所以不等式右侧为 ln L K + max { q i k j T d } \ln LK + \max \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} lnLK+max{dqikjT},因为所有的内积都相同,所以最大值也是 q i k T d \frac{q_i k^T}{\sqrt{d}} dqikT。
这样,通过具体的例子说明,不等式的左右两部分在不同情况下如何达到平衡,从而证明了不等式的正确性。