我觉得这是一个很有意思的问题,简单但是很细节。先说结论,是为了保证梯度的平稳。那怎么个意思?
首先说向量(行向量和列向量都一样),他们的点乘和叉乘。
向量的内积:也叫点乘,结果是一个数。两个向量对应位相乘再求和。要求向量a和b的维度要一样。
a
⃗
∗
b
⃗
=
(
a
1
∗
b
1
+
a
2
∗
b
2
+
⋯
+
a
n
∗
b
n
)
\vec{a}*\vec{b}=(a_1*b_1+a_2*b_2+\cdots+a_n*b_n)
a∗b=(a1∗b1+a2∗b2+⋯+an∗bn)
内积的几何意义:计算两个向量之间的夹角或者向量b在向量a上的投影。
a
⃗
∗
b
⃗
=
∣
a
∣
∗
∣
b
∣
∗
c
o
s
(
θ
)
(
余
弦
定
理
可
以
证
)
\vec{a}*\vec{b}=|a|*|b|*cos(\theta)(余弦定理可以证)
a∗b=∣a∣∗∣b∣∗cos(θ)(余弦定理可以证)
θ
=
a
r
c
c
o
s
(
a
⃗
∗
b
⃗
∣
a
∣
∣
b
∣
)
\theta=arccos(\frac{\vec{a}*\vec{b}}{|a||b|})
θ=arccos(∣a∣∣b∣a∗b)
向量的外积:也叫叉乘,结果是一个新的向量。具体的来说它是a向量和b向量组成平面的法向量。
a
⃗
x
b
⃗
=
∣
i
j
k
x
a
y
a
z
a
x
b
y
b
z
b
∣
\vec{a}x\vec{b}=\begin{vmatrix} i&j&k\\ x_a&y_a&z_a\\ x_b&y_b&z_b \end{vmatrix}
axb=∣∣∣∣∣∣ixaxbjyaybkzazb∣∣∣∣∣∣
下边进入正文Self-attention的细节。
S
e
l
f
a
t
t
e
n
t
i
o
n
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
Self attention=softmax(\frac{QK^T}{\sqrt{d_k}})V
Selfattention=softmax(dkQKT)V
读过《Attention is all you need》我们就知道QKV三个矩阵都是X的线性变换。这里为了简单,我们认为QKV都是同样的一个矩阵,也就是Q=K=V。假设Q是一个行向量,维度为
d
k
d_k
dk。那么我们可以知道
Q
Q
T
QQ^T
QQT其实在计算每个元素之间的相似度。而且不存在上下文关系,也就是全局的相似性关系。那么继续我们如果假设Q是从一个标准正太分布(0均值,1方差的高斯分布)中产生的,那么
Q
Q
T
QQ^T
QQT就也是0均值,
2
d
k
2d_k
2dk为方差的。为什么?
因 Q Q T QQ^T QQT相乘之后我理解现在是一个卡方分布,不知道这里理解的对不对,希望和大家一起探讨。那么卡方分布的方差 E ( x 2 ) = 2 d k E(x^2)= 2d_k E(x2)=2dk。
所以为了让 Q Q T QQ^T QQT的方差回到1,我们需要除 d k \sqrt{d_k} dk。又来了,为什么想要让方差回到1呢?
因为方差大, Q Q T QQ^T QQT中出现大值的可能性就大,下边引用我看到的文章的一段话。
当 d k d_k dk很大时,意味着 Q Q T QQ^T QQT的方差就很大,分布会趋于陡峭(分布的方差大,分布就会集中在绝对值大的区域),就会使得softmax()之后使得值出现两极分化的状态。(https://blog.csdn.net/qq_44846512/article/details/114364559)
也就是说方差大,那么经过softmax后输出的矩阵softmax( Q Q T QQ^T QQT),会很陡峭。这句话乍一看可能比较含糊,我后来自己特意看了下结果,就明白了。
import torch
import matplotlib.pyplot as plt
import math
def main():
matsize = 10
q = torch.randn(matsize)
k = torch.randn(matsize*matsize)
v = torch.randn(matsize*matsize*matsize)
c1 = q*q
c2 = k*k
c3 = v*v
for var in q, k, v, c1, c2, c3:
print("mean is %f, div is %f." % (var.mean(), var.var()))
ax1 = plt.subplot(331) # c1 origin
plt.plot(torch.arange(c1.shape[0]), c1)
ax2 = plt.subplot(334) # c1 softmax
plt.plot(torch.arange(c1.shape[0]), torch.nn.functional.softmax(c1))
ax3 = plt.subplot(332) # c2 origin
plt.plot(torch.arange(c2.shape[0]), c2)
ax4 = plt.subplot(335) # c2 softmax
plt.plot(torch.arange(c2.shape[0]), torch.nn.functional.softmax(c2))
ax6 = plt.subplot(338) # c2 softmax with sqrt dk
plt.plot(torch.arange(c2.shape[0]),
torch.nn.functional.softmax(c2/math.sqrt(c2.shape[0])))
ax7 = plt.subplot(333) # c3 origin
plt.plot(torch.arange(c3.shape[0]), c3)
ax8 = plt.subplot(336) # c3 softmax
plt.plot(torch.arange(c3.shape[0]), torch.nn.functional.softmax(c3))
ax9 = plt.subplot(339) # c3 softmax with sqrt dk
plt.plot(torch.arange(c3.shape[0]),
torch.nn.functional.softmax(c3/math.sqrt(c3.shape[0])))
plt.savefig("cov.jpg")
plt.show()
if __name__ == "__main__":
main()
通过下边的这个图
图中最上边一行是softmax之前的结果,中间一行是没有除 d k d_k dk的softmax结果,最后一行是除了 d k d_k dk的softmax结果。可以看出在不除 d k d_k dk的时候softmax的结果只会在输入的最大值或者几个大值附近出现,看起来非常陡峭。当输入除了 d k d_k dk以后我们发现输入数据的分布大部分都保留了下来,这样的好处就是可以在梯度回传的时候让梯度比较平稳。而且当 d k d_k dk越大,影响越明显(从左向右 d k d_k dk越来越大)。
这就是为什么Self-attention中要除 d k d_k dk。这也是我看了一些网上的资料后自己的理解,只不过我觉得其中有不清楚的地方自己又想了下写下来而已。
参考文献
- https://www.zhihu.com/question/293696778 卡方分布方差计算
- https://blog.csdn.net/qq_44846512/article/details/114364559 关于这个问题写的也不错的博客