Self-attention中为什么softmax要除d_k

我觉得这是一个很有意思的问题,简单但是很细节。先说结论,是为了保证梯度的平稳。那怎么个意思?

首先说向量(行向量和列向量都一样),他们的点乘和叉乘。

向量的内积:也叫点乘,结果是一个数。两个向量对应位相乘再求和。要求向量a和b的维度要一样。
a ⃗ ∗ b ⃗ = ( a 1 ∗ b 1 + a 2 ∗ b 2 + ⋯ + a n ∗ b n ) \vec{a}*\vec{b}=(a_1*b_1+a_2*b_2+\cdots+a_n*b_n) a b =(a1b1+a2b2++anbn)
内积的几何意义:计算两个向量之间的夹角或者向量b在向量a上的投影。
a ⃗ ∗ b ⃗ = ∣ a ∣ ∗ ∣ b ∣ ∗ c o s ( θ ) ( 余 弦 定 理 可 以 证 ) \vec{a}*\vec{b}=|a|*|b|*cos(\theta)(余弦定理可以证) a b =abcos(θ)()

θ = a r c c o s ( a ⃗ ∗ b ⃗ ∣ a ∣ ∣ b ∣ ) \theta=arccos(\frac{\vec{a}*\vec{b}}{|a||b|}) θ=arccos(aba b )
向量的外积:也叫叉乘,结果是一个新的向量。具体的来说它是a向量和b向量组成平面的法向量。
a ⃗ x b ⃗ = ∣ i j k x a y a z a x b y b z b ∣ \vec{a}x\vec{b}=\begin{vmatrix} i&j&k\\ x_a&y_a&z_a\\ x_b&y_b&z_b \end{vmatrix} a xb =ixaxbjyaybkzazb
下边进入正文Self-attention的细节。
S e l f a t t e n t i o n = s o f t m a x ( Q K T d k ) V Self attention=softmax(\frac{QK^T}{\sqrt{d_k}})V Selfattention=softmax(dk QKT)V
读过《Attention is all you need》我们就知道QKV三个矩阵都是X的线性变换。这里为了简单,我们认为QKV都是同样的一个矩阵,也就是Q=K=V。假设Q是一个行向量,维度为 d k d_k dk。那么我们可以知道 Q Q T QQ^T QQT其实在计算每个元素之间的相似度。而且不存在上下文关系,也就是全局的相似性关系。那么继续我们如果假设Q是从一个标准正太分布(0均值,1方差的高斯分布)中产生的,那么 Q Q T QQ^T QQT就也是0均值, 2 d k 2d_k 2dk为方差的。为什么?

Q Q T QQ^T QQT相乘之后我理解现在是一个卡方分布,不知道这里理解的对不对,希望和大家一起探讨。那么卡方分布的方差 E ( x 2 ) = 2 d k E(x^2)= 2d_k E(x2)=2dk

所以为了让 Q Q T QQ^T QQT的方差回到1,我们需要除 d k \sqrt{d_k} dk 。又来了,为什么想要让方差回到1呢?

因为方差大, Q Q T QQ^T QQT中出现大值的可能性就大,下边引用我看到的文章的一段话。

d k d_k dk很大时,意味着 Q Q T QQ^T QQT的方差就很大,分布会趋于陡峭(分布的方差大,分布就会集中在绝对值大的区域),就会使得softmax()之后使得值出现两极分化的状态。(https://blog.csdn.net/qq_44846512/article/details/114364559)

也就是说方差大,那么经过softmax后输出的矩阵softmax( Q Q T QQ^T QQT),会很陡峭。这句话乍一看可能比较含糊,我后来自己特意看了下结果,就明白了。

import torch
import matplotlib.pyplot as plt
import math


def main():
    matsize = 10

    q = torch.randn(matsize)
    k = torch.randn(matsize*matsize)
    v = torch.randn(matsize*matsize*matsize)
    c1 = q*q
    c2 = k*k
    c3 = v*v
    for var in q, k, v, c1, c2, c3:
        print("mean is %f, div is %f." % (var.mean(), var.var()))

    ax1 = plt.subplot(331)  # c1 origin
    plt.plot(torch.arange(c1.shape[0]), c1)
    ax2 = plt.subplot(334)  # c1 softmax
    plt.plot(torch.arange(c1.shape[0]), torch.nn.functional.softmax(c1))
    ax3 = plt.subplot(332)  # c2 origin
    plt.plot(torch.arange(c2.shape[0]), c2)
    ax4 = plt.subplot(335)  # c2 softmax
    plt.plot(torch.arange(c2.shape[0]), torch.nn.functional.softmax(c2))
    ax6 = plt.subplot(338)  # c2 softmax with sqrt dk
    plt.plot(torch.arange(c2.shape[0]),
             torch.nn.functional.softmax(c2/math.sqrt(c2.shape[0])))
    ax7 = plt.subplot(333)  # c3 origin
    plt.plot(torch.arange(c3.shape[0]), c3)
    ax8 = plt.subplot(336)  # c3 softmax
    plt.plot(torch.arange(c3.shape[0]), torch.nn.functional.softmax(c3))
    ax9 = plt.subplot(339)  # c3 softmax with sqrt dk
    plt.plot(torch.arange(c3.shape[0]),
             torch.nn.functional.softmax(c3/math.sqrt(c3.shape[0])))

    plt.savefig("cov.jpg")
    plt.show()


if __name__ == "__main__":
    main()

通过下边的这个图

8子图

图中最上边一行是softmax之前的结果,中间一行是没有除 d k d_k dk的softmax结果,最后一行是除了 d k d_k dk的softmax结果。可以看出在不除 d k d_k dk的时候softmax的结果只会在输入的最大值或者几个大值附近出现,看起来非常陡峭。当输入除了 d k d_k dk以后我们发现输入数据的分布大部分都保留了下来,这样的好处就是可以在梯度回传的时候让梯度比较平稳。而且当 d k d_k dk越大,影响越明显(从左向右 d k d_k dk越来越大)。

这就是为什么Self-attention中要除 d k d_k dk。这也是我看了一些网上的资料后自己的理解,只不过我觉得其中有不清楚的地方自己又想了下写下来而已。

参考文献

  1. https://www.zhihu.com/question/293696778 卡方分布方差计算
  2. https://blog.csdn.net/qq_44846512/article/details/114364559 关于这个问题写的也不错的博客
  • 12
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值