Attention两篇文章链接:其中一个是Luong,提的dot product attention, 另一个是Vaswali的scaled dot product attention , 也就是大名鼎鼎的attention is all you need。
说到attention不再过多赘述,论文中的公式推导感觉比较简单,结合自己的理解写一下矩阵层面的表示。数学好的可以跳过。
在attention is all you need这篇文章中,他是这么写的:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V
Attention(Q,K,V)=softmax(dkQKT)V,而Luong那篇文章中,公式比较多且分散。
但无论如何,大致总结是(看下图),先算attention score(Q,K相乘),再用softmax算distribution,再把distribution和hidden state相乘获得attention output(最上面那个MatMul),再把output和另一个hidden相加(concat)。
这里以seq2seq模型中的attention为例。
首先我们有encoder hidden state的一个序列:
H
=
[
h
1
,
h
2
,
.
.
.
h
N
]
H =[h_1, h_2, ...h_N]
H=[h1,h2,...hN]
然后有
t
t
t 时刻的decoder state
s
t
s^t
st, 所有时刻的decoder state就是
S
=
[
s
1
,
s
2
,
.
.
.
,
s
t
]
S=[s^1, s^2, ..., s^t]
S=[s1,s2,...,st]。
每次用所有的encoder hidden state去和当前时刻的decoder state相乘(dot product)
对于
t
t
t时刻而言的attention score就是用
e
t
=
[
h
1
T
s
t
,
h
2
T
s
t
,
.
.
.
,
h
N
T
s
t
]
e^t = [h_1^Ts^t, h_2^Ts^t, ..., h_N^Ts^t]
et=[h1Tst,h2Tst,...,hNTst],
但实际在计算中,我们是把整个decoder hidden state和encoder hidden state乘起来,而不是像循环一样对每个时刻都依次计算
E
=
[
h
1
T
s
1
h
2
T
s
1
,
.
.
.
,
h
N
T
s
1
⋮
⋮
⋱
⋮
h
1
T
s
t
−
1
h
2
T
s
t
−
1
,
.
.
.
,
h
N
T
s
t
−
1
h
1
T
s
t
h
2
T
s
t
,
.
.
.
,
h
N
T
s
t
]
=
[
s
1
s
2
⋮
s
t
]
⋅
[
h
1
T
h
2
T
⋯
h
N
T
]
(
1
)
E = \left[ \begin{matrix} h_1^Ts^1 & h_2^Ts^1, &..., &h_N^Ts^1 \\ \vdots & \vdots &\ddots &\vdots\\ h_1^Ts^{t-1} & h_2^Ts^{t-1}, &..., &h_N^Ts^{t-1}\\ h_1^Ts^t & h_2^Ts^t, &..., &h_N^Ts^t \end{matrix} \right] = \left[ \begin{matrix}s^1\\ s^2\\ \vdots\\ s^t \end{matrix} \right] \cdot \left[ \begin{matrix} h_1^T & h_2^T & \cdots & & h_N^T \end{matrix} \right] \ (1)
E=⎣⎢⎢⎢⎡h1Ts1⋮h1Tst−1h1Tsth2Ts1,⋮h2Tst−1,h2Tst,...,⋱...,...,hNTs1⋮hNTst−1hNTst⎦⎥⎥⎥⎤=⎣⎢⎢⎢⎡s1s2⋮st⎦⎥⎥⎥⎤⋅[h1Th2T⋯hNT] (1)
这里看不懂没关系,先看下一步如何计算attention distribution的计算,这里其实就是用softmax把
e
t
e^t
et中的每个元素映射到[0,1]之间并且总和为1:
α
t
=
s
o
f
t
m
a
x
(
e
t
)
\alpha^t = softmax(e^t)
αt=softmax(et)
如果用矩阵的方式表示其实就是
s
o
f
t
m
a
x
(
E
)
=
[
α
1
,
α
2
,
.
.
.
,
α
t
]
T
softmax(E) = [\alpha^1, \alpha^2, ..., \alpha^t]^T
softmax(E)=[α1,α2,...,αt]T 其中e的右上标是时刻
t
t
t的意思, 而不是指数。其中每一个
α
\alpha
α都是一个1堆概率,长度为N,,也就是encoder hidden state的长度,并且每一个
α
\alpha
α总和为1。
下一步是将attention distribution和encoder hidden state相乘获得attention output, 即把每一个
α
\alpha
α中的每一个概率和对应位置的encoder hidden state相乘,再求和:
a
t
=
Σ
i
=
1
N
α
i
t
h
i
=
[
α
1
t
,
α
2
t
,
α
3
t
,
.
.
.
α
i
t
]
⋅
[
h
1
h
2
h
3
⋮
h
i
]
a_t = \Sigma_{i=1}^N\alpha_i^th_i = [\alpha_1^t, \alpha_2^t, \alpha_3^t, ... \alpha_i^t] \cdot \left[ \begin{matrix}h_1 \\ h_2\\ h_3\\ \vdots\\ h_i\end{matrix} \right]
at=Σi=1Nαithi=[α1t,α2t,α3t,...αit]⋅⎣⎢⎢⎢⎢⎢⎡h1h2h3⋮hi⎦⎥⎥⎥⎥⎥⎤
这个操作对应于下图的encoder recurrent layer指向attention distribution的那条红线,以及Attention distribution指向Attention output的黑线以及最上面的三角形。
把每个时刻的
a
t
a_t
at都放到一起的话:
A
=
[
a
1
,
a
2
,
a
3
,
.
.
.
,
a
t
]
T
=
[
a
1
a
2
⋮
a
t
]
=
[
α
1
1
α
2
1
⋯
α
i
1
α
1
2
α
2
2
⋯
α
i
2
⋮
⋮
⋱
⋮
α
1
t
α
2
t
⋯
α
i
t
]
⋅
[
h
1
h
2
⋮
h
i
]
=
S
o
f
t
m
a
x
(
E
)
⋅
H
T
=
S
o
f
t
m
a
x
(
S
⋅
H
T
)
⋅
H
T
A = [a_1, a_2, a_3, ..., a_t]^T = \left[ \begin{matrix} a_1\\ a_2\\ \vdots\\ a_t \end{matrix} \right] = \left[ \begin{matrix} \alpha_1^1 & \alpha_2^1 & \cdots & \alpha_i^1\\ \alpha_1^2 & \alpha_2^2 & \cdots & \alpha_i^2\\ \vdots & \vdots & \ddots& \vdots\\ \alpha_1^t & \alpha_2^t & \cdots & \alpha_i^t\end{matrix} \right] \cdot \left[ \begin{matrix} h_1 \\h_2 \\ \vdots\\ h_i \end{matrix} \right] \\=Softmax(E) \cdot H^T \\=Softmax(S \cdot H^T) \cdot H^T
A=[a1,a2,a3,...,at]T=⎣⎢⎢⎢⎡a1a2⋮at⎦⎥⎥⎥⎤=⎣⎢⎢⎢⎡α11α12⋮α1tα21α22⋮α2t⋯⋯⋱⋯αi1αi2⋮αit⎦⎥⎥⎥⎤⋅⎣⎢⎢⎢⎡h1h2⋮hi⎦⎥⎥⎥⎤=Softmax(E)⋅HT=Softmax(S⋅HT)⋅HT
A中的每一个
a
t
a_t
at是当前
t
t
t时刻的attention output,需要和对应时刻的decoder state的s^t执行concat操作。如下图
写成公式就是
c
o
n
c
a
t
[
a
t
,
s
t
]
concat[a_t, s_t]
concat[at,st],然后对每一个时刻的decoder state都执行这样的操作,其实就是concat(A, S),其中A就是
[
a
1
,
a
2
,
.
.
.
,
a
t
]
=
[
Σ
i
=
1
N
α
i
1
h
i
Σ
i
=
1
N
α
i
2
h
i
.
.
.
Σ
i
=
1
N
α
i
t
h
i
]
[a_1, a_2, ..., a_t] = \left[ \begin{matrix}\Sigma_{i=1}^N\alpha_i^1h_i & \Sigma_{i=1}^N\alpha_i^2h_i & ... &\Sigma_{i=1}^N\alpha_i^th_i\end{matrix} \right]
[a1,a2,...,at]=[Σi=1Nαi1hiΣi=1Nαi2hi...Σi=1Nαithi], 而S就是
[
s
1
,
s
2
,
.
.
.
,
s
t
]
[s_1, s_2, ..., s_t]
[s1,s2,...,st]
所以最后的attention就等于 c o n c a t ( S o f t m a x ( S ⋅ H T ) ⋅ H T , S ) concat(Softmax(S \cdot H^T) \cdot H^T, S) concat(Softmax(S⋅HT)⋅HT,S)
翻译成代码,我们只需要知道S,H,并且写一个矩阵乘法,softmax和concat即可,在pytorch中分别可以通过torch.bmm(), F.softmax(), torch.cat()实现。
遇到不同的attention score计算方式,只需要修改Softmax中的计算即可,比如我要用scaled dot product, softmax层里面就是 S ⋅ H T / n S\cdot H^T/\sqrt{n} S⋅HT/n其中n是input的长度, 其他地方不需要修改。