informer之Proof of Lemma 1(引理1证明)M(q_i, K)

我们用详细的数学推导和结论证明了该不等式的正确性,我们将其分为两个部分进行讨论:不等式的左部分和右部分。
在这里插入图片描述
在这里插入图片描述

证明不等式:

ln ⁡ L K ≤ M ( q i , K ) ≤ max ⁡ j { q i k j T d } − 1 L K ∑ j = 1 L K { q i k j T d } + ln ⁡ L K \ln LK \leq M(q_i, K) \leq \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} - \frac{1}{LK} \sum_{j=1}^{LK} \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} + \ln LK lnLKM(qi,K)maxj{d qikjT}LK1j=1LK{d qikjT}+lnLK

其中, M ( q i , K ) M(q_i, K) M(qi,K)定义为:

M ( q i , K ) = ln ⁡ ( ∑ j = 1 L K exp ⁡ ( q i k j T d ) ) − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) = \ln \left( \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \right) - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)=ln(j=1LKexp(d qikjT))LK1j=1LK(d qikjT)

左部分的解释

首先,我们看不等式的左部分。对于每一个查询向量 q i q_i qi,第一个项 M ( q i , K ) M(q_i, K) M(qi,K)是一个对数和指数函数,计算固定查询 q i q_i qi和所有键的内积。我们定义:
f i ( K ) = ln ⁡ ∑ j = 1 L K exp ⁡ ( q i k j T d ) f_i(K) = \ln \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) fi(K)=lnj=1LKexp(d qikjT)

根据Log-sum-exp网络的相关理论(Calafiore, Gaubert和Possieri 2018)以及进一步的分析,函数 f i ( K ) f_i(K) fi(K)是一个凸函数。此外, f i ( K ) f_i(K) fi(K)加上线性组合 k j k_j kj使得 M ( q i , K ) M(q_i, K) M(qi,K)成为固定查询情况下的凸函数。

然后,我们对单个向量 k j k_j kj求导数:
∂ M ( q i , K ) ∂ k j = exp ⁡ ( q i k j T d ) ∑ j = 1 L K exp ⁡ ( q i k j T d ) ⋅ q i d − 1 L K ⋅ q i d \frac{\partial M(q_i, K)}{\partial k_j} = \frac{\exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)} \cdot \frac{q_i}{\sqrt{d}} - \frac{1}{LK} \cdot \frac{q_i}{\sqrt{d}} kjM(qi,K)=j=1LKexp(d qikjT)exp(d qikjT)d qiLK1d qi

我们再将其整理为:
∂ M ( q i , K ) ∂ k j = ( exp ⁡ ( q i k j T d ) ∑ j = 1 L K exp ⁡ ( q i k j T d ) − 1 L K ) ⋅ q i d \frac{\partial M(q_i, K)}{\partial k_j} = \left( \frac{\exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)} - \frac{1}{LK} \right) \cdot \frac{q_i}{\sqrt{d}} kjM(qi,K)= j=1LKexp(d qikjT)exp(d qikjT)LK1 d qi

最小值条件

在这里插入图片描述

为了找到该函数的最小值,我们需要让所有的导数为零,也就是令梯度为零:即:
∂ M ( q i , K ) ∂ k j = 0 \frac{\partial M(q_i, K)}{\partial k_j} = 0 kjM(qi,K)=0

这要求:
exp ⁡ ( q i k j T d ) ∑ j = 1 L K exp ⁡ ( q i k j T d ) = 1 L K \frac{\exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)} = \frac{1}{LK} j=1LKexp(d qikjT)exp(d qikjT)=LK1

这意味着:
exp ⁡ ( q i k j T d ) = ∑ j = 1 L K exp ⁡ ( q i k j T d ) L K \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) = \frac{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{LK} exp(d qikjT)=LKj=1LKexp(d qikjT)

对数化后,我们有:
q i k j T d = ln ⁡ ( ∑ j = 1 L K exp ⁡ ( q i k j T d ) L K ) \frac{q_i k_j^T}{\sqrt{d}} = \ln \left( \frac{\sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right)}{LK} \right) d qikjT=ln LKj=1LKexp(d qikjT)

进一步化简:
q i k j T + ln ⁡ L K = ln ⁡ ∑ j = 1 L K exp ⁡ ( q i k j T ) q_i k_j^T + \ln LK = \ln \sum_{j=1}^{LK} \exp \left( {q_i k_j^T}{} \right) qikjT+lnLK=lnj=1LKexp(qikjT)

由于所有的 k j k_j kj对应的值是相同的,我们得到:
k 1 = k 2 = ⋯ = k L K k_1 = k_2 = \cdots = k_{LK} k1=k2==kLK

此时,最小值为:
M ( q i , K ) = ln ⁡ L K M(q_i, K) = \ln LK M(qi,K)=lnLK

自然地,这需要 k 1 = k 2 = ⋯ = k L K k_1 = k_2 = \cdots = k_{LK} k1=k2==kLK,我们有测量的最小值为 ln ⁡ L K \ln LK lnLK,即:
M ( q i , K ) ≥ ln ⁡ L K M(q_i, K) \geq \ln LK M(qi,K)lnLK

证明不等式的右半部分

我们要证明不等式的右半部分,使用如下定义的 M ( q i , K ) M(q_i, K) M(qi,K)

M ( q i , K ) = ln ⁡ ( ∑ j = 1 L K exp ⁡ ( q i k j T d ) ) − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) = \ln \left( \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \right) - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)=ln(j=1LKexp(d qikjT))LK1j=1LK(d qikjT)

右半部分的证明步骤

根据图中的方法,我们逐步进行证明:

  1. 定义最大内积:
    α = max ⁡ j { q i k j T d } \alpha = \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} α=maxj{d qikjT}

  2. 上界估计:
    ∑ j = 1 L K exp ⁡ ( q i k j T d ) ≤ ∑ j = 1 L K exp ⁡ ( α ) = L K ⋅ exp ⁡ ( α ) \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \leq \sum_{j=1}^{LK} \exp(\alpha) = LK \cdot \exp(\alpha) j=1LKexp(d qikjT)j=1LKexp(α)=LKexp(α)

  3. 对数运算:
    ln ⁡ ( ∑ j = 1 L K exp ⁡ ( q i k j T d ) ) ≤ ln ⁡ ( L K ⋅ exp ⁡ ( α ) ) = ln ⁡ L K + α \ln \left( \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \right) \leq \ln \left( LK \cdot \exp(\alpha) \right) = \ln LK + \alpha ln(j=1LKexp(d qikjT))ln(LKexp(α))=lnLK+α

  4. 结合 M ( q i , K ) M(q_i, K) M(qi,K)的定义:

    我们有:
    M ( q i , K ) = ln ⁡ ( ∑ j = 1 L K exp ⁡ ( q i k j T d ) ) − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) = \ln \left( \sum_{j=1}^{LK} \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) \right) - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)=ln(j=1LKexp(d qikjT))LK1j=1LK(d qikjT)

    使用第3步的结果:
    M ( q i , K ) ≤ ln ⁡ L K + α − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) \leq \ln LK + \alpha - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)lnLK+αLK1j=1LK(d qikjT)

    由于 α \alpha α是定义的最大值:
    α = max ⁡ j { q i k j T d } \alpha = \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} α=maxj{d qikjT}

  5. 最终表达式:
    α \alpha α替换回表达式中,得到:
    M ( q i , K ) ≤ ln ⁡ L K + max ⁡ j { q i k j T d } − 1 L K ∑ j = 1 L K ( q i k j T d ) M(q_i, K) \leq \ln LK + \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) M(qi,K)lnLK+maxj{d qikjT}LK1j=1LK(d qikjT)

最终结论

综上所述,我们已经证明了不等式的右半部分:
M ( q i , K ) ≤ max ⁡ j { q i k j T d } − 1 L K ∑ j = 1 L K ( q i k j T d ) + ln ⁡ L K M(q_i, K) \leq \max_j \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} - \frac{1}{LK} \sum_{j=1}^{LK} \left( \frac{q_i k_j^T}{\sqrt{d}} \right) + \ln LK M(qi,K)maxj{d qikjT}LK1j=1LK(d qikjT)+lnLK

因此,引理1的右半部分不等式得证。

举例说明

假设有一个查询向量 q i q_i qi和一个键集合 K = { k 1 , k 2 , … , k L K } K = \{k_1, k_2, \ldots, k_{LK}\} K={k1,k2,,kLK}。假设 d = 1 d = 1 d=1并且每个键向量 k j k_j kj都是相同的向量 k k k。此时, q i q_i qi和每个键的内积相同,即 q i k T q_i k^T qikT。根据以上推导:

  • 对于左部分,因为所有 k j k_j kj都相同,所以 exp ⁡ ( q i k j T d ) \exp \left( \frac{q_i k_j^T}{\sqrt{d}} \right) exp(d qikjT)的和为 L K ⋅ exp ⁡ ( q i k T d ) LK \cdot \exp \left( \frac{q_i k^T}{\sqrt{d}} \right) LKexp(d qikT),对数之后得到 ln ⁡ L K + q i k T d \ln LK + \frac{q_i k^T}{\sqrt{d}} lnLK+d qikT
  • 对于右部分,选择最大的内积(其实就是 q i k T q_i k^T qikT),所以不等式右侧为 ln ⁡ L K + max ⁡ { q i k j T d } \ln LK + \max \left\{ \frac{q_i k_j^T}{\sqrt{d}} \right\} lnLK+max{d qikjT},因为所有的内积都相同,所以最大值也是 q i k T d \frac{q_i k^T}{\sqrt{d}} d qikT

这样,通过具体的例子说明,不等式的左右两部分在不同情况下如何达到平衡,从而证明了不等式的正确性。

  • 27
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值