Attention is all you need 公式推导
前言
Transformer的根源在于这篇文章,但这篇paper本身写的并不好懂,因为省去了大量的细节。依照上交许志钦老师的讲解才理清头绪,所以我准备以公式推导的方式记录下来这篇文章的流程。
并没有看代码,毕竟只是作为研究DETR的准备工作,听完许老师的课理论准备就够了。所以下面也都是基于对许老师课程的记录,记录和理解有不准确的地方还望大家指正。
一、以训练的角度看
整篇笔记都是按照这张图的符号约定进行的。
(一)输入的内容编码
X ˉ ∈ R n × N \bar{X} \in \mathbb{R}^{n \times N} Xˉ∈Rn×N表示经one-hot变换后的输入, n n n表示一句话中的单词数量, N N N表示输入字典大小。关于字典是什么意思请详见one-hot编码。
为了Batch操作,会将输入小于n的句子进行padding,让其长度等于n。
one-hot编码过于冗余,过于稀疏,即绝大部分位置都是0,只有使用的单词对应位置是1,因此效率不高。可以利用单词含义之间的相关性,使用一个更高效地表示。
X
~
=
X
ˉ
W
∈
R
n
×
d
m
\tilde{X} = \bar{X} W \in \mathbb{R}^{n \times d_m}
X~=XˉW∈Rn×dm
其中
d
m
d_m
dm小于
N
N
N,也就是,最终我们使用一个长度为
d
m
d_m
dm的向量表示一个单词 ,
W
W
W可训练。
(二)输入的位置编码
位置编码positional encoding是为了考虑一句话中单词位置对于翻译的影响。使用sin的编码方式是想达到:单词之间在位置上的相关性,只依赖于两个单词的相对位置,而不受绝对位置影响。 而两个单词位置相关性,就是对应位置编码的內积。
必须要清楚的是,位置编码只受单词位置的影响,而不受单词含义的影响
于是,位置编码后的矩阵
P
E
∈
R
n
×
d
m
PE \in \mathbb{R}^{n \times d_m}
PE∈Rn×dm,
P
E
PE
PE中的元素
P
E
(
p
o
s
,
i
)
PE(pos,i)
PE(pos,i)具体为
P
E
(
p
o
s
,
2
i
)
=
s
i
n
(
p
o
s
/
1000
0
2
i
/
d
m
)
P
E
(
p
o
s
,
2
i
+
1
)
=
c
o
s
(
p
o
s
/
1000
0
2
i
/
d
m
)
PE(pos,2i) = sin(pos/10000^{2i/d_m}) \\ PE(pos,2i+1) = cos(pos/10000^{2i/d_m})
PE(pos,2i)=sin(pos/100002i/dm)PE(pos,2i+1)=cos(pos/100002i/dm)
其中,pos表示一句话中的第pos个单词,i是该单词的第i个维度
最终,位置编码和内容编码加起来,终于得到了Transformer的输入
X
X
X,
X
=
P
E
+
X
~
X = PE+ \tilde{X}
X=PE+X~
单头的attention
Q
=
X
W
Q
K
=
X
W
K
V
=
X
W
V
Q=XW^Q \\ K=XW^K \\ V=XW^V
Q=XWQK=XWKV=XWV
其中,
Q
,
K
,
V
∈
R
n
×
d
m
Q,K,V \in \mathbb{R}^{n \times d_m}
Q,K,V∈Rn×dm
一句话中某个单词和其他单词的相关性可以表示成
Q
K
T
QK^T
QKT,然后再稍加处理
A
=
s
o
f
t
m
a
x
(
Q
K
T
d
m
)
A = softmax(\frac{QK^T}{\sqrt {d_m}})
A=softmax(dmQKT)
这里有三点需要注意:
- 除以 d m \sqrt {d_m} dm 是为了阻止 Q K T QK^T QKT过大,防止softmax之后,相关的地方很大,不太相关的地方近似为0。若不防止此情况,产生的梯度会叫较小,不利于训练。(这似乎更多是从实践的角度得到的结论,并没有很多的理论依据)
- softmax是对矩阵的每行独立进行的,即对每个单词独立进行的。
- 前面提到为了是所有输入的句子长度相等,进行了padding操作。但你不会希望句子中的某个单词和padding的值有相关性,在此处进行了mask处理。
- 比如padding的值为0,那么在softmax中 e 0 e^0 e0不等于0,这就产生了相关性
- 理想情况是padding处使用 − ∞ -\infty −∞,( e − ∞ = 0 e^{-\infty}=0 e−∞=0),程序中使用-1e9
在计算完相关性之后,在得到输出
X
[
1
]
=
A
V
X^{[1]}=AV
X[1]=AV
即
X
[
1
]
=
s
o
f
t
m
a
x
(
Q
K
T
d
m
)
⋅
V
X^{[1]}=softmax(\frac{QK^T}{\sqrt {d_m}}) \cdot V
X[1]=softmax(dmQKT)⋅V
multi-head attention
也就是流程图中使用的版本,和单头的版本类似于组卷积和卷积关系。
首先确定每个head的维度,一共h个head
d
q
=
d
k
=
d
v
=
d
m
/
h
d_q = d_k = d_v = d_m/h
dq=dk=dv=dm/h
第i个head中QKV的计算方式
Q
i
=
Q
W
i
Q
K
i
=
K
W
i
K
V
i
=
V
W
i
V
X
~
i
=
s
o
f
t
m
a
x
(
Q
i
K
i
T
d
m
/
h
)
⋅
V
i
Q_i=QW^Q_i \\ K_i=KW^K_i \\ V_i=VW^V_i \\ \tilde{X}_i = softmax(\frac{Q_iK_i^T}{\sqrt {d_m/h}}) \cdot V_i
Qi=QWiQKi=KWiKVi=VWiVX~i=softmax(dm/hQiKiT)⋅Vi
其中,
Q
i
,
K
i
,
V
i
,
X
~
i
∈
R
n
×
(
d
m
/
h
)
Q_i,K_i,V_i, \tilde{X}_i \in \mathbb{R}^{n \times (d_m/h)}
Qi,Ki,Vi,X~i∈Rn×(dm/h)。
X
~
i
\tilde{X}_i
X~i 只是一个中间量,但注意计算时候若有padding需进行mask操作。然后将所有head的输出合并起来
X
~
=
[
X
~
1
,
.
.
.
X
~
i
,
.
.
.
,
X
~
h
]
∈
R
n
×
d
m
X
[
1
]
=
X
~
W
∈
R
n
×
d
m
\tilde{X} = [\tilde{X}_1,...\tilde{X}_i,...,\tilde{X}_h] \in \mathbb{R}^{n \times d_m} \\ X^{[1]} = \tilde{X} W \in \mathbb{R}^{n \times d_m}
X~=[X~1,...X~i,...,X~h]∈Rn×dmX[1]=X~W∈Rn×dm
注:程序中生成
Q
i
,
K
i
,
V
i
Q_i,K_i,V_i
Qi,Ki,Vi的做法是直接对
Q
,
K
,
V
Q,K,V
Q,K,V进行划分。实在是不想用电脑画图了。
Encoder的后续
X
[
2
]
=
X
[
1
]
+
X
X
[
3
]
=
L
a
y
e
r
N
o
r
m
(
X
[
2
]
)
X^{[2]} = X^{[1]} + X \\ X^{[3]} = LayerNorm( X^{[2]} )
X[2]=X[1]+XX[3]=LayerNorm(X[2])
几种Norm的对比
X
[
4
]
=
F
F
N
(
X
[
3
]
)
X^{[4]} = FFN( X^{[3]} )
X[4]=FFN(X[3])
其中FFN为Feed Forward Net,就是几个全连接层。然后使用ADD和Norm得到Encoder的输出
X
[
6
]
X^{[6]}
X[6]
Decoder
在训练时,
Y
ˉ
\bar{Y}
Yˉ就是groundtruth,
Y
Y
Y的生成方式可类比于
X
X
X。
重点在于Masked-multi-head-attention的理解。
假设输出的句子有10个单词,在Decoder确定第4个单词的时候,它只能依赖前3个单词,输出第5个的时候只能依赖前4个,以此类推。为了完成这个效果需要一个下三角矩阵(对角线上方为0),作用在相关性计算的结果上
A
~
i
o
=
Q
i
o
⋅
K
i
o
T
d
m
/
h
\tilde{A}_i^o = \frac{Q_i^o \cdot {K_i^o}^T}{\sqrt {d_m/h}}
A~io=dm/hQio⋅KioT
而
A
~
i
o
\tilde{A}_i^o
A~io与mask元素相乘后可以达到上述效果
Y
~
i
=
s
o
f
t
m
a
x
(
Q
i
o
⋅
K
i
o
T
d
m
/
h
⊙
m
a
s
k
T
)
⋅
V
i
o
\tilde{Y}_i = softmax(\frac{Q_i^o \cdot {K_i^o}^T}{\sqrt {d_m/h}} \odot mask^T ) \cdot V_i^o
Y~i=softmax(dm/hQio⋅KioT⊙maskT)⋅Vio
这是以多头的形式写的公式,mask为下三角矩阵,
Q
i
o
,
K
i
o
,
V
i
o
Q_i^o , K_i^o ,V_i^o
Qio,Kio,Vio根据
Y
Y
Y生成。