Attention is all you need 公式推导

前言

Transformer的根源在于这篇文章,但这篇paper本身写的并不好懂,因为省去了大量的细节。依照上交许志钦老师的讲解才理清头绪,所以我准备以公式推导的方式记录下来这篇文章的流程。

并没有看代码,毕竟只是作为研究DETR的准备工作,听完许老师的课理论准备就够了。所以下面也都是基于对许老师课程的记录,记录和理解有不准确的地方还望大家指正。

一、以训练的角度看

整篇笔记都是按照这张图的符号约定进行的。
在这里插入图片描述

(一)输入的内容编码

X ˉ ∈ R n × N \bar{X} \in \mathbb{R}^{n \times N} XˉRn×N表示经one-hot变换后的输入, n n n表示一句话中的单词数量, N N N表示输入字典大小。关于字典是什么意思请详见one-hot编码。

为了Batch操作,会将输入小于n的句子进行padding,让其长度等于n。

one-hot编码过于冗余,过于稀疏,即绝大部分位置都是0,只有使用的单词对应位置是1,因此效率不高。可以利用单词含义之间的相关性,使用一个更高效地表示。
X ~ = X ˉ W ∈ R n × d m \tilde{X} = \bar{X} W \in \mathbb{R}^{n \times d_m} X~=XˉWRn×dm
其中 d m d_m dm小于 N N N也就是,最终我们使用一个长度为 d m d_m dm的向量表示一个单词 W W W可训练。

(二)输入的位置编码

位置编码positional encoding是为了考虑一句话中单词位置对于翻译的影响。使用sin的编码方式是想达到:单词之间在位置上的相关性,只依赖于两个单词的相对位置,而不受绝对位置影响。 而两个单词位置相关性,就是对应位置编码的內积。

必须要清楚的是,位置编码只受单词位置的影响,而不受单词含义的影响

于是,位置编码后的矩阵 P E ∈ R n × d m PE \in \mathbb{R}^{n \times d_m} PERn×dm P E PE PE中的元素 P E ( p o s , i ) PE(pos,i) PE(pos,i)具体为
P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i / d m ) P E ( p o s , 2 i + 1 ) = c o s ( p o s / 1000 0 2 i / d m ) PE(pos,2i) = sin(pos/10000^{2i/d_m}) \\ PE(pos,2i+1) = cos(pos/10000^{2i/d_m}) PE(pos,2i)=sin(pos/100002i/dm)PE(pos,2i+1)=cos(pos/100002i/dm)
其中,pos表示一句话中的第pos个单词,i是该单词的第i个维度

最终,位置编码和内容编码加起来,终于得到了Transformer的输入 X X X,
X = P E + X ~ X = PE+ \tilde{X} X=PE+X~

单头的attention

Q = X W Q K = X W K V = X W V Q=XW^Q \\ K=XW^K \\ V=XW^V Q=XWQK=XWKV=XWV
其中, Q , K , V ∈ R n × d m Q,K,V \in \mathbb{R}^{n \times d_m} Q,K,VRn×dm

一句话中某个单词和其他单词的相关性可以表示成 Q K T QK^T QKT,然后再稍加处理
A = s o f t m a x ( Q K T d m ) A = softmax(\frac{QK^T}{\sqrt {d_m}}) A=softmax(dm QKT)
这里有三点需要注意:

  1. 除以 d m \sqrt {d_m} dm 是为了阻止 Q K T QK^T QKT过大,防止softmax之后,相关的地方很大,不太相关的地方近似为0。若不防止此情况,产生的梯度会叫较小,不利于训练。(这似乎更多是从实践的角度得到的结论,并没有很多的理论依据)
  2. softmax是对矩阵的每行独立进行的,即对每个单词独立进行的。
  3. 前面提到为了是所有输入的句子长度相等,进行了padding操作。但你不会希望句子中的某个单词和padding的值有相关性,在此处进行了mask处理。
    1. 比如padding的值为0,那么在softmax中 e 0 e^0 e0不等于0,这就产生了相关性
    2. 理想情况是padding处使用 − ∞ -\infty ,( e − ∞ = 0 e^{-\infty}=0 e=0),程序中使用-1e9

在计算完相关性之后,在得到输出
X [ 1 ] = A V X^{[1]}=AV X[1]=AV
X [ 1 ] = s o f t m a x ( Q K T d m ) ⋅ V X^{[1]}=softmax(\frac{QK^T}{\sqrt {d_m}}) \cdot V X[1]=softmax(dm QKT)V

multi-head attention

也就是流程图中使用的版本,和单头的版本类似于组卷积和卷积关系。
首先确定每个head的维度,一共h个head
d q = d k = d v = d m / h d_q = d_k = d_v = d_m/h dq=dk=dv=dm/h
第i个head中QKV的计算方式
Q i = Q W i Q K i = K W i K V i = V W i V X ~ i = s o f t m a x ( Q i K i T d m / h ) ⋅ V i Q_i=QW^Q_i \\ K_i=KW^K_i \\ V_i=VW^V_i \\ \tilde{X}_i = softmax(\frac{Q_iK_i^T}{\sqrt {d_m/h}}) \cdot V_i Qi=QWiQKi=KWiKVi=VWiVX~i=softmax(dm/h QiKiT)Vi
其中, Q i , K i , V i , X ~ i ∈ R n × ( d m / h ) Q_i,K_i,V_i, \tilde{X}_i \in \mathbb{R}^{n \times (d_m/h)} Qi,Ki,Vi,X~iRn×(dm/h) X ~ i \tilde{X}_i X~i 只是一个中间量,但注意计算时候若有padding需进行mask操作。然后将所有head的输出合并起来
X ~ = [ X ~ 1 , . . . X ~ i , . . . , X ~ h ] ∈ R n × d m X [ 1 ] = X ~ W ∈ R n × d m \tilde{X} = [\tilde{X}_1,...\tilde{X}_i,...,\tilde{X}_h] \in \mathbb{R}^{n \times d_m} \\ X^{[1]} = \tilde{X} W \in \mathbb{R}^{n \times d_m} X~=[X~1,...X~i,...,X~h]Rn×dmX[1]=X~WRn×dm

注:程序中生成 Q i , K i , V i Q_i,K_i,V_i Qi,Ki,Vi的做法是直接对 Q , K , V Q,K,V Q,K,V进行划分。实在是不想用电脑画图了。
请添加图片描述

Encoder的后续

X [ 2 ] = X [ 1 ] + X X [ 3 ] = L a y e r N o r m ( X [ 2 ] ) X^{[2]} = X^{[1]} + X \\ X^{[3]} = LayerNorm( X^{[2]} ) X[2]=X[1]+XX[3]=LayerNorm(X[2])
几种Norm的对比
请添加图片描述
X [ 4 ] = F F N ( X [ 3 ] ) X^{[4]} = FFN( X^{[3]} ) X[4]=FFN(X[3])
其中FFN为Feed Forward Net,就是几个全连接层。然后使用ADD和Norm得到Encoder的输出 X [ 6 ] X^{[6]} X[6]

Decoder

在训练时, Y ˉ \bar{Y} Yˉ就是groundtruth, Y Y Y的生成方式可类比于 X X X
重点在于Masked-multi-head-attention的理解。

假设输出的句子有10个单词,在Decoder确定第4个单词的时候,它只能依赖前3个单词,输出第5个的时候只能依赖前4个,以此类推。为了完成这个效果需要一个下三角矩阵(对角线上方为0),作用在相关性计算的结果上

A ~ i o = Q i o ⋅ K i o T d m / h \tilde{A}_i^o = \frac{Q_i^o \cdot {K_i^o}^T}{\sqrt {d_m/h}} A~io=dm/h QioKioT
A ~ i o \tilde{A}_i^o A~io与mask元素相乘后可以达到上述效果

Y ~ i = s o f t m a x ( Q i o ⋅ K i o T d m / h ⊙ m a s k T ) ⋅ V i o \tilde{Y}_i = softmax(\frac{Q_i^o \cdot {K_i^o}^T}{\sqrt {d_m/h}} \odot mask^T ) \cdot V_i^o Y~i=softmax(dm/h QioKioTmaskT)Vio
这是以多头的形式写的公式,mask为下三角矩阵, Q i o , K i o , V i o Q_i^o , K_i^o ,V_i^o Qio,Kio,Vio根据 Y Y Y生成。

参考

  1. 细节:Attention is all your need-transformer
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值