Attention的矩阵表示及理解

Attention两篇文章链接:其中一个是Luong,提的dot product attention, 另一个是Vaswali的scaled dot product attention , 也就是大名鼎鼎的attention is all you need。

说到attention不再过多赘述,论文中的公式推导感觉比较简单,结合自己的理解写一下矩阵层面的表示。数学好的可以跳过。

在attention is all you need这篇文章中,他是这么写的:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V,而Luong那篇文章中,公式比较多且分散。
但无论如何,大致总结是(看下图),先算attention score(Q,K相乘),再用softmax算distribution,再把distribution和hidden state相乘获得attention output(最上面那个MatMul),再把output和另一个hidden相加(concat)。在这里插入图片描述

这里以seq2seq模型中的attention为例。

首先我们有encoder hidden state的一个序列: H = [ h 1 , h 2 , . . . h N ] H =[h_1, h_2, ...h_N] H=[h1,h2,...hN]
然后有 t t t 时刻的decoder state s t s^t st, 所有时刻的decoder state就是 S = [ s 1 , s 2 , . . . , s t ] S=[s^1, s^2, ..., s^t] S=[s1,s2,...,st]
每次用所有的encoder hidden state去和当前时刻的decoder state相乘(dot product)

对于 t t t时刻而言的attention score就是用 e t = [ h 1 T s t , h 2 T s t , . . . , h N T s t ] e^t = [h_1^Ts^t, h_2^Ts^t, ..., h_N^Ts^t] et=[h1Tst,h2Tst,...,hNTst]
但实际在计算中,我们是把整个decoder hidden state和encoder hidden state乘起来,而不是像循环一样对每个时刻都依次计算
E = [ h 1 T s 1 h 2 T s 1 , . . . , h N T s 1 ⋮ ⋮ ⋱ ⋮ h 1 T s t − 1 h 2 T s t − 1 , . . . , h N T s t − 1 h 1 T s t h 2 T s t , . . . , h N T s t ] = [ s 1 s 2 ⋮ s t ] ⋅ [ h 1 T h 2 T ⋯ h N T ]   ( 1 ) E = \left[ \begin{matrix} h_1^Ts^1 & h_2^Ts^1, &..., &h_N^Ts^1 \\ \vdots & \vdots &\ddots &\vdots\\ h_1^Ts^{t-1} & h_2^Ts^{t-1}, &..., &h_N^Ts^{t-1}\\ h_1^Ts^t & h_2^Ts^t, &..., &h_N^Ts^t \end{matrix} \right] = \left[ \begin{matrix}s^1\\ s^2\\ \vdots\\ s^t \end{matrix} \right] \cdot \left[ \begin{matrix} h_1^T & h_2^T & \cdots & & h_N^T \end{matrix} \right] \ (1) E=h1Ts1h1Tst1h1Tsth2Ts1,h2Tst1,h2Tst,...,...,...,hNTs1hNTst1hNTst=s1s2st[h1Th2ThNT] (1)

这里看不懂没关系,先看下一步如何计算attention distribution的计算,这里其实就是用softmax把 e t e^t et中的每个元素映射到[0,1]之间并且总和为1:
α t = s o f t m a x ( e t ) \alpha^t = softmax(e^t) αt=softmax(et)

如果用矩阵的方式表示其实就是
s o f t m a x ( E ) = [ α 1 , α 2 , . . . , α t ] T softmax(E) = [\alpha^1, \alpha^2, ..., \alpha^t]^T softmax(E)=[α1,α2,...,αt]T 其中e的右上标是时刻 t t t的意思, 而不是指数。其中每一个 α \alpha α都是一个1堆概率,长度为N,,也就是encoder hidden state的长度,并且每一个 α \alpha α总和为1。

下一步是将attention distribution和encoder hidden state相乘获得attention output, 即把每一个 α \alpha α中的每一个概率和对应位置的encoder hidden state相乘,再求和:
a t = Σ i = 1 N α i t h i = [ α 1 t , α 2 t , α 3 t , . . . α i t ] ⋅ [ h 1 h 2 h 3 ⋮ h i ] a_t = \Sigma_{i=1}^N\alpha_i^th_i = [\alpha_1^t, \alpha_2^t, \alpha_3^t, ... \alpha_i^t] \cdot \left[ \begin{matrix}h_1 \\ h_2\\ h_3\\ \vdots\\ h_i\end{matrix} \right] at=Σi=1Nαithi=[α1t,α2t,α3t,...αit]h1h2h3hi
这个操作对应于下图的encoder recurrent layer指向attention distribution的那条红线,以及Attention distribution指向Attention output的黑线以及最上面的三角形。
在这里插入图片描述
把每个时刻的 a t a_t at都放到一起的话:
A = [ a 1 , a 2 , a 3 , . . . , a t ] T = [ a 1 a 2 ⋮ a t ] = [ α 1 1 α 2 1 ⋯ α i 1 α 1 2 α 2 2 ⋯ α i 2 ⋮ ⋮ ⋱ ⋮ α 1 t α 2 t ⋯ α i t ] ⋅ [ h 1 h 2 ⋮ h i ] = S o f t m a x ( E ) ⋅ H T = S o f t m a x ( S ⋅ H T ) ⋅ H T A = [a_1, a_2, a_3, ..., a_t]^T = \left[ \begin{matrix} a_1\\ a_2\\ \vdots\\ a_t \end{matrix} \right] = \left[ \begin{matrix} \alpha_1^1 & \alpha_2^1 & \cdots & \alpha_i^1\\ \alpha_1^2 & \alpha_2^2 & \cdots & \alpha_i^2\\ \vdots & \vdots & \ddots& \vdots\\ \alpha_1^t & \alpha_2^t & \cdots & \alpha_i^t\end{matrix} \right] \cdot \left[ \begin{matrix} h_1 \\h_2 \\ \vdots\\ h_i \end{matrix} \right] \\=Softmax(E) \cdot H^T \\=Softmax(S \cdot H^T) \cdot H^T A=[a1,a2,a3,...,at]T=a1a2at=α11α12α1tα21α22α2tαi1αi2αith1h2hi=Softmax(E)HT=Softmax(SHT)HT

A中的每一个 a t a_t at是当前 t t t时刻的attention output,需要和对应时刻的decoder state的s^t执行concat操作。如下图
在这里插入图片描述
写成公式就是 c o n c a t [ a t , s t ] concat[a_t, s_t] concat[at,st],然后对每一个时刻的decoder state都执行这样的操作,其实就是concat(A, S),其中A就是 [ a 1 , a 2 , . . . , a t ] = [ Σ i = 1 N α i 1 h i Σ i = 1 N α i 2 h i . . . Σ i = 1 N α i t h i ] [a_1, a_2, ..., a_t] = \left[ \begin{matrix}\Sigma_{i=1}^N\alpha_i^1h_i & \Sigma_{i=1}^N\alpha_i^2h_i & ... &\Sigma_{i=1}^N\alpha_i^th_i\end{matrix} \right] [a1,a2,...,at]=[Σi=1Nαi1hiΣi=1Nαi2hi...Σi=1Nαithi], 而S就是 [ s 1 , s 2 , . . . , s t ] [s_1, s_2, ..., s_t] [s1,s2,...,st]

所以最后的attention就等于 c o n c a t ( S o f t m a x ( S ⋅ H T ) ⋅ H T , S ) concat(Softmax(S \cdot H^T) \cdot H^T, S) concat(Softmax(SHT)HT,S)

翻译成代码,我们只需要知道S,H,并且写一个矩阵乘法,softmax和concat即可,在pytorch中分别可以通过torch.bmm(), F.softmax(), torch.cat()实现。

遇到不同的attention score计算方式,只需要修改Softmax中的计算即可,比如我要用scaled dot product, softmax层里面就是 S ⋅ H T / n S\cdot H^T/\sqrt{n} SHT/n 其中n是input的长度, 其他地方不需要修改。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值