transformer中的attention为什么scaled?

  原文链接:transformer中的attention为什么scaled?——LinT的回答

  这个问题困扰良久,一直没研究清楚,只知道个大概,不知其所以然,这里专门开一篇总结一下。由于有人珠玉在前,写得极其精彩,所以直接转载了,以下为原文。

———————————————————————————————————————————————

  谢邀。非常有意义的问题,我思考了好久,按照描述中的两个问题分点回答一下。

1. 为什么比较大的输入会使得softmax的梯度变得很小?

  对于一个输入向量 x ∈ R d \mathbf{x} \in \mathbb{R}^{d} xRd ,softmax函数将其映射/归一化到一个分布 y ^ ∈ R d \hat{\mathbf{y}} \in \mathbb{R}^{d} y^Rd。在这个过程中,softmax先用一个自然底数 e e e将输入中的元素间差距先“拉大”,然后归一化为一个分布。假设某个输入 x x x中最大的的元素下标是 k k k,如果输入的数量级变大(每个元素都很大),那么 y ^ k \hat{y}_{k} y^k会非常接近1。

  我们可以用一个小例子来看看 x x x的数量级对输入最大元素对应的预测概率 y ^ k \hat{y}_{k} y^k的影响。假定输入 x = [ a , a , 2 a ] ⊤ \mathbf{x}=[a, a, 2 a]^{\top} x=[a,a,2a]),我们来看不同量级的 a a a产生的 y ^ 3 \hat{y}_{3} y^3有什么区别。

  • a = 1 a=1 a=1时, y ^ 3 = 0.5761168847658291 \hat{y}_{3}=0.5761168847658291 y^3=0.5761168847658291
  • a = 10 a=10 a=10时, y ^ 3 = 0.999909208384341 \hat{y}_{3}=0.999909208384341 y^3=0.999909208384341
  • a = 100 a=100 a=100时, y ^ 3 ≈ 1.0 \hat{y}_{3} \approx 1.0 y^31.0(计算精度限制);

  我们不妨把 a a a在不同取值下,对应的 y ^ 3 \hat{y}_{3} y^3全部绘制出来。代码如下:

from math import exp
from matplotlib import pyplot as plt
import numpy as np 
f = lambda x: exp(x * 2) / (exp(x) + exp(x) + exp(x * 2))
x = np.linspace(0, 100, 100)
y_3 = [f(x_i) for x_i in x]
plt.plot(x, y_3)
plt.show()

  得到的图如下所示:
在这里插入图片描述
  可以看到,数量级对softmax得到的分布影响非常大。在数量级较大时,softmax将几乎全部的概率分布都分配给了最大值对应的标签。

  然后我们来看softmax的梯度。不妨简记softmax函数为 g ( ⋅ ) g(\cdot) g(),softmax得到的分布向量 y ^ = g ( x ) \hat{\mathbf{y}}=g(\mathbf{x}) y^=g(x)对输入 x x x的梯度为:
∂ g ( x ) ∂ x = diag ⁡ ( y ^ ) − y ^ y ^ ⊤ ∈ R d × d \frac{\partial g(\mathbf{x})}{\partial \mathbf{x}}=\operatorname{diag}(\hat{\mathbf{y}})-\hat{\mathbf{y}} \hat{\mathbf{y}}^{\top} \quad \in \mathbb{R}^{d \times d} xg(x)=diag(y^)y^y^Rd×d  把这个矩阵展开:
∂ g ( x ) ∂ x = [ y ^ 1 0 ⋯ 0 0 y ^ 2 ⋯ 0 ⋮ ⋮ ⋱ ⋮ 0 0 ⋯ y ^ d ] − [ y ^ 1 2 y ^ 1 y ^ 2 ⋯ y ^ 1 y ^ d y ^ 2 y ^ 1 y ^ 2 2 ⋯ y ^ 2 y ^ d ⋮ ⋮ ⋱ ⋮ y ^ d y ^ 1 y ^ d y ^ 2 ⋯ y ^ d 2 ] \frac{\partial g(\mathbf{x})}{\partial \mathbf{x}}=\left[\begin{array}{cccc} \hat{y}_{1} & 0 & \cdots & 0 \\ 0 & \hat{y}_{2} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \hat{y}_{d} \end{array}\right]-\left[\begin{array}{cccc} \hat{y}_{1}^{2} & \hat{y}_{1} \hat{y}_{2} & \cdots & \hat{y}_{1} \hat{y}_{d} \\ \hat{y}_{2} \hat{y}_{1} & \hat{y}_{2}^{2} & \cdots & \hat{y}_{2} \hat{y}_{d} \\ \vdots & \vdots & \ddots & \vdots \\ \hat{y}_{d} \hat{y}_{1} & \hat{y}_{d} \hat{y}_{2} & \cdots & \hat{y}_{d}^{2} \end{array}\right] xg(x)= y^1000y^2000y^d y^12y^2y^1y^dy^1y^1y^2y^22y^dy^2y^1y^dy^2y^dy^d2   根据前面的讨论,当输入 x x x的元素均较大时,softmax会把大部分概率分布分配给最大的元素,假设我们的输入数量级很大,最大的元素是 x 1 x_1 x1,那么就将产生一个接近one-hot的向量 y ^ ≈ [ 1 , 0 , ⋯   , 0 ] ⊤ \hat{\mathbf{y}} \approx[1,0, \cdots, 0]^{\top} y^[1,0,,0],此时上面的矩阵变为如下形式:
∂ g ( x ) ∂ x ≈ [ 1 0 ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ ⋮ 0 0 ⋯ 0 ] − [ 1 0 ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ ⋮ 0 0 ⋯ 0 ] = 0 \frac{\partial g(\mathbf{x})}{\partial \mathbf{x}} \approx\left[\begin{array}{cccc} 1 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 0 \end{array}\right]-\left[\begin{array}{cccc} 1 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 0 \end{array}\right]=\mathbf{0} xg(x) 100000000 100000000 =0  也就是说,在输入的数量级很大时,梯度消失为0,造成参数更新困难。

  注: softmax的梯度可以自行推导,网络上也有很多推导可以参考。

2. 维度与点积大小的关系是怎么样的,为什么使用维度的根号来放缩?

  针对为什么维度会影响点积的大小,在论文的脚注中其实给出了一点解释:
在这里插入图片描述  假设向量 q q q k k k的各个分量是互相独立的随机变量,均值是0,方差是1,那么点积 q ⋅ k q \cdot k qk的均值是0,方差是 d k d_k dk。这里我给出一点更详细的推导:

  对 ∀ i = 1 , ⋯   , d k \forall i=1, \cdots, d_{k} i=1,,dk q i q_i qi k i k_i ki都是随机变量,为了方便书写,不妨记 X = q i X=q_i X=qi Y = k i Y=k_i Y=ki。这样有: D ( X ) = D ( Y ) = 1 D(X)=D(Y)=1 D(X)=D(Y)=1 E ( X ) = E ( Y ) = 0 E(X)=E(Y)=0 E(X)=E(Y)=0。则:

  1. E ( X Y ) = E ( X ) E ( Y ) = 0 × 0 = 0 E(X Y)=E(X) E(Y)=0 \times 0=0 E(XY)=E(X)E(Y)=0×0=0
  2. D ( X Y ) = E ( X 2 ⋅ Y 2 ) − [ E ( X Y ) ] 2 = E ( X 2 ) E ( Y 2 ) − [ E ( X ) E ( Y ) ] 2 = E ( X 2 − 0 2 ) E ( Y 2 − 0 2 ) − [ E ( X ) E ( Y ) ] 2 = E ( X 2 − [ E ( X ) ] 2 ) E ( Y 2 − [ E ( Y ) ] 2 ) − [ E ( X ) E ( Y ) ] 2 = D ( X ) D ( Y ) − [ E ( X ) E ( Y ) ] 2 = 1 × 1 − ( 0 × 0 ) 2 = 1 \begin{aligned} D(X Y) &=E\left(X^{2} \cdot Y^{2}\right)-[E(X Y)]^{2} \\ &=E\left(X^{2}\right) E\left(Y^{2}\right)-[E(X) E(Y)]^{2} \\ &=E\left(X^{2}-0^{2}\right) E\left(Y^{2}-0^{2}\right)-[E(X) E(Y)]^{2} \\ &=E\left(X^{2}-[E(X)]^{2}\right) E\left(Y^{2}-[E(Y)]^{2}\right)-[E(X) E(Y)]^{2} \\ &=D(X) D(Y)-[E(X) E(Y)]^{2} \\ &=1 \times 1-(0 \times 0)^{2} \\ &=1 \end{aligned} D(XY)=E(X2Y2)[E(XY)]2=E(X2)E(Y2)[E(X)E(Y)]2=E(X202)E(Y202)[E(X)E(Y)]2=E(X2[E(X)]2)E(Y2[E(Y)]2)[E(X)E(Y)]2=D(X)D(Y)[E(X)E(Y)]2=1×1(0×0)2=1

  这样 ∀ i = 1 , ⋯   , d k \forall i=1, \cdots, d_{k} i=1,,dk q i ⋅ k i q_i \cdot k_i qiki的均值是0,方差是1,又由期望和方差的性质, 对相互独立的分量 z i z_i zi,有
E ( ∑ i Z i ) = ∑ i E ( Z i ) E\left(\sum_{i} Z_{i}\right)=\sum_{i} E\left(Z_{i}\right) E(iZi)=iE(Zi)
  以及
D ( ∑ i Z i ) = ∑ i D ( Z i ) D\left(\sum_{i} Z_{i}\right)=\sum_{i} D\left(Z_{i}\right) D(iZi)=iD(Zi)
  所以有 q ⋅ k q \cdot k qk的均值 E ( q ⋅ k ) = 0 E(q \cdot k)=0 E(qk)=0,方差 D ( q ⋅ k ) = d k D(q \cdot k)=d_{k} D(qk)=dk。方差越大也就说明,点积的数量级越大(以越大的概率取大值)。那么一个自然的做法就是把方差稳定到1,做法是将点积除以 d k \sqrt{d}_{k} d k,这样有:
D ( q ⋅ k d k ) = d k ( d k ) 2 = 1 D\left(\frac{q \cdot k}{\sqrt{d}_{k}}\right)=\frac{d_{k}}{\left(\sqrt{d}_{k}\right)^{2}}=1 D(d kqk)=(d k)2dk=1  将方差控制为1,也就有效地控制了前面提到的梯度消失的问题。

  可以参考一下。水平有限,如果有误请指出。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值