原文链接:transformer中的attention为什么scaled?——LinT的回答
这个问题困扰良久,一直没研究清楚,只知道个大概,不知其所以然,这里专门开一篇总结一下。由于有人珠玉在前,写得极其精彩,所以直接转载了,以下为原文。
———————————————————————————————————————————————
谢邀。非常有意义的问题,我思考了好久,按照描述中的两个问题分点回答一下。
1. 为什么比较大的输入会使得softmax的梯度变得很小?
对于一个输入向量 x ∈ R d \mathbf{x} \in \mathbb{R}^{d} x∈Rd ,softmax函数将其映射/归一化到一个分布 y ^ ∈ R d \hat{\mathbf{y}} \in \mathbb{R}^{d} y^∈Rd。在这个过程中,softmax先用一个自然底数 e e e将输入中的元素间差距先“拉大”,然后归一化为一个分布。假设某个输入 x x x中最大的的元素下标是 k k k,如果输入的数量级变大(每个元素都很大),那么 y ^ k \hat{y}_{k} y^k会非常接近1。
我们可以用一个小例子来看看 x x x的数量级对输入最大元素对应的预测概率 y ^ k \hat{y}_{k} y^k的影响。假定输入 x = [ a , a , 2 a ] ⊤ \mathbf{x}=[a, a, 2 a]^{\top} x=[a,a,2a]⊤),我们来看不同量级的 a a a产生的 y ^ 3 \hat{y}_{3} y^3有什么区别。
- a = 1 a=1 a=1时, y ^ 3 = 0.5761168847658291 \hat{y}_{3}=0.5761168847658291 y^3=0.5761168847658291;
- a = 10 a=10 a=10时, y ^ 3 = 0.999909208384341 \hat{y}_{3}=0.999909208384341 y^3=0.999909208384341;
- a = 100 a=100 a=100时, y ^ 3 ≈ 1.0 \hat{y}_{3} \approx 1.0 y^3≈1.0(计算精度限制);
我们不妨把 a a a在不同取值下,对应的 y ^ 3 \hat{y}_{3} y^3全部绘制出来。代码如下:
from math import exp
from matplotlib import pyplot as plt
import numpy as np
f = lambda x: exp(x * 2) / (exp(x) + exp(x) + exp(x * 2))
x = np.linspace(0, 100, 100)
y_3 = [f(x_i) for x_i in x]
plt.plot(x, y_3)
plt.show()
得到的图如下所示:
可以看到,数量级对softmax得到的分布影响非常大。在数量级较大时,softmax将几乎全部的概率分布都分配给了最大值对应的标签。
然后我们来看softmax的梯度。不妨简记softmax函数为
g
(
⋅
)
g(\cdot)
g(⋅),softmax得到的分布向量
y
^
=
g
(
x
)
\hat{\mathbf{y}}=g(\mathbf{x})
y^=g(x)对输入
x
x
x的梯度为:
∂
g
(
x
)
∂
x
=
diag
(
y
^
)
−
y
^
y
^
⊤
∈
R
d
×
d
\frac{\partial g(\mathbf{x})}{\partial \mathbf{x}}=\operatorname{diag}(\hat{\mathbf{y}})-\hat{\mathbf{y}} \hat{\mathbf{y}}^{\top} \quad \in \mathbb{R}^{d \times d}
∂x∂g(x)=diag(y^)−y^y^⊤∈Rd×d 把这个矩阵展开:
∂
g
(
x
)
∂
x
=
[
y
^
1
0
⋯
0
0
y
^
2
⋯
0
⋮
⋮
⋱
⋮
0
0
⋯
y
^
d
]
−
[
y
^
1
2
y
^
1
y
^
2
⋯
y
^
1
y
^
d
y
^
2
y
^
1
y
^
2
2
⋯
y
^
2
y
^
d
⋮
⋮
⋱
⋮
y
^
d
y
^
1
y
^
d
y
^
2
⋯
y
^
d
2
]
\frac{\partial g(\mathbf{x})}{\partial \mathbf{x}}=\left[\begin{array}{cccc} \hat{y}_{1} & 0 & \cdots & 0 \\ 0 & \hat{y}_{2} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \hat{y}_{d} \end{array}\right]-\left[\begin{array}{cccc} \hat{y}_{1}^{2} & \hat{y}_{1} \hat{y}_{2} & \cdots & \hat{y}_{1} \hat{y}_{d} \\ \hat{y}_{2} \hat{y}_{1} & \hat{y}_{2}^{2} & \cdots & \hat{y}_{2} \hat{y}_{d} \\ \vdots & \vdots & \ddots & \vdots \\ \hat{y}_{d} \hat{y}_{1} & \hat{y}_{d} \hat{y}_{2} & \cdots & \hat{y}_{d}^{2} \end{array}\right]
∂x∂g(x)=⎣
⎡y^10⋮00y^2⋮0⋯⋯⋱⋯00⋮y^d⎦
⎤−⎣
⎡y^12y^2y^1⋮y^dy^1y^1y^2y^22⋮y^dy^2⋯⋯⋱⋯y^1y^dy^2y^d⋮y^d2⎦
⎤ 根据前面的讨论,当输入
x
x
x的元素均较大时,softmax会把大部分概率分布分配给最大的元素,假设我们的输入数量级很大,最大的元素是
x
1
x_1
x1,那么就将产生一个接近one-hot的向量
y
^
≈
[
1
,
0
,
⋯
,
0
]
⊤
\hat{\mathbf{y}} \approx[1,0, \cdots, 0]^{\top}
y^≈[1,0,⋯,0]⊤,此时上面的矩阵变为如下形式:
∂
g
(
x
)
∂
x
≈
[
1
0
⋯
0
0
0
⋯
0
⋮
⋮
⋱
⋮
0
0
⋯
0
]
−
[
1
0
⋯
0
0
0
⋯
0
⋮
⋮
⋱
⋮
0
0
⋯
0
]
=
0
\frac{\partial g(\mathbf{x})}{\partial \mathbf{x}} \approx\left[\begin{array}{cccc} 1 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 0 \end{array}\right]-\left[\begin{array}{cccc} 1 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 0 \end{array}\right]=\mathbf{0}
∂x∂g(x)≈⎣
⎡10⋮000⋮0⋯⋯⋱⋯00⋮0⎦
⎤−⎣
⎡10⋮000⋮0⋯⋯⋱⋯00⋮0⎦
⎤=0 也就是说,在输入的数量级很大时,梯度消失为0,造成参数更新困难。
注: softmax的梯度可以自行推导,网络上也有很多推导可以参考。
2. 维度与点积大小的关系是怎么样的,为什么使用维度的根号来放缩?
针对为什么维度会影响点积的大小,在论文的脚注中其实给出了一点解释:
假设向量
q
q
q和
k
k
k的各个分量是互相独立的随机变量,均值是0,方差是1,那么点积
q
⋅
k
q \cdot k
q⋅k的均值是0,方差是
d
k
d_k
dk。这里我给出一点更详细的推导:
对 ∀ i = 1 , ⋯ , d k \forall i=1, \cdots, d_{k} ∀i=1,⋯,dk, q i q_i qi和 k i k_i ki都是随机变量,为了方便书写,不妨记 X = q i X=q_i X=qi, Y = k i Y=k_i Y=ki。这样有: D ( X ) = D ( Y ) = 1 D(X)=D(Y)=1 D(X)=D(Y)=1, E ( X ) = E ( Y ) = 0 E(X)=E(Y)=0 E(X)=E(Y)=0。则:
- E ( X Y ) = E ( X ) E ( Y ) = 0 × 0 = 0 E(X Y)=E(X) E(Y)=0 \times 0=0 E(XY)=E(X)E(Y)=0×0=0
- D ( X Y ) = E ( X 2 ⋅ Y 2 ) − [ E ( X Y ) ] 2 = E ( X 2 ) E ( Y 2 ) − [ E ( X ) E ( Y ) ] 2 = E ( X 2 − 0 2 ) E ( Y 2 − 0 2 ) − [ E ( X ) E ( Y ) ] 2 = E ( X 2 − [ E ( X ) ] 2 ) E ( Y 2 − [ E ( Y ) ] 2 ) − [ E ( X ) E ( Y ) ] 2 = D ( X ) D ( Y ) − [ E ( X ) E ( Y ) ] 2 = 1 × 1 − ( 0 × 0 ) 2 = 1 \begin{aligned} D(X Y) &=E\left(X^{2} \cdot Y^{2}\right)-[E(X Y)]^{2} \\ &=E\left(X^{2}\right) E\left(Y^{2}\right)-[E(X) E(Y)]^{2} \\ &=E\left(X^{2}-0^{2}\right) E\left(Y^{2}-0^{2}\right)-[E(X) E(Y)]^{2} \\ &=E\left(X^{2}-[E(X)]^{2}\right) E\left(Y^{2}-[E(Y)]^{2}\right)-[E(X) E(Y)]^{2} \\ &=D(X) D(Y)-[E(X) E(Y)]^{2} \\ &=1 \times 1-(0 \times 0)^{2} \\ &=1 \end{aligned} D(XY)=E(X2⋅Y2)−[E(XY)]2=E(X2)E(Y2)−[E(X)E(Y)]2=E(X2−02)E(Y2−02)−[E(X)E(Y)]2=E(X2−[E(X)]2)E(Y2−[E(Y)]2)−[E(X)E(Y)]2=D(X)D(Y)−[E(X)E(Y)]2=1×1−(0×0)2=1
这样
∀
i
=
1
,
⋯
,
d
k
\forall i=1, \cdots, d_{k}
∀i=1,⋯,dk,
q
i
⋅
k
i
q_i \cdot k_i
qi⋅ki的均值是0,方差是1,又由期望和方差的性质, 对相互独立的分量
z
i
z_i
zi,有
E
(
∑
i
Z
i
)
=
∑
i
E
(
Z
i
)
E\left(\sum_{i} Z_{i}\right)=\sum_{i} E\left(Z_{i}\right)
E(i∑Zi)=i∑E(Zi)
以及
D
(
∑
i
Z
i
)
=
∑
i
D
(
Z
i
)
D\left(\sum_{i} Z_{i}\right)=\sum_{i} D\left(Z_{i}\right)
D(i∑Zi)=i∑D(Zi)
所以有
q
⋅
k
q \cdot k
q⋅k的均值
E
(
q
⋅
k
)
=
0
E(q \cdot k)=0
E(q⋅k)=0,方差
D
(
q
⋅
k
)
=
d
k
D(q \cdot k)=d_{k}
D(q⋅k)=dk。方差越大也就说明,点积的数量级越大(以越大的概率取大值)。那么一个自然的做法就是把方差稳定到1,做法是将点积除以
d
k
\sqrt{d}_{k}
dk,这样有:
D
(
q
⋅
k
d
k
)
=
d
k
(
d
k
)
2
=
1
D\left(\frac{q \cdot k}{\sqrt{d}_{k}}\right)=\frac{d_{k}}{\left(\sqrt{d}_{k}\right)^{2}}=1
D(dkq⋅k)=(dk)2dk=1 将方差控制为1,也就有效地控制了前面提到的梯度消失的问题。
可以参考一下。水平有限,如果有误请指出。