目录
Self-Attention的结构图
本文侧重于Pytorch中对self-attention的具体实践,具体原理不作大量说明,self-attention的具体结构请参照下图。
(图中为输出第二项attention output的情况,k与q为key、query的缩写)
本文中将使用Pytorch的torch.nn.MultiheadAttention来实现self-attention.
forward输入中的query、key、value
首先,前三个输入是最重要的部分query、key、value。由图1可知,我们self-attention的这三样东西其实是一样的,它们的形状都是:(L,N,E) 1。
L:输入sequence的长度(例如一个句子的长度)
N:批大小(例如一个批的句子个数)
E:词向量长度
forward的输出
输出的内容很少只有两项:
-
attn_output
即通过self-attention之后,从每一个词语位置输出来的attention。其形状为(L,N,E),是和输入的query它们形状一样的。因为毕竟只是给value乘了一个weight。 -
attn_output_weights
即attention weights,形状是(N,L,L),因为每一个单词和任意另一个单词之间都会产生一个weight,所以每一句句子的weight数量是L*L
实例化一个nn.MultiheadAttention
这里对MultiheadAttention进行一个实例化并传入一些参数,实例化之后我们得到的东西我们就可以向它传入input了。
实例化时的代码:
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
其中,embed_dim是每一个单词本来的词向量长度;num_heads是我们MultiheadAttention的head的数量。
pytorch的MultiheadAttention应该使用的是Narrow self-attention机制,即,把embedding分割成num_heads份,每一份分别拿来做一下attention。
也就是说:单词1的第一份、单词2的第一份、单词3的第一份…会当成一个sequence,做一次我们图1所示的self-attention。
然后,单词1的第二份、单词2的第二份、单词3的第二份…也会做一次
直到单词1的第num_heads份、单词2的第num_heads份、单词3的第num_heads份…也做完self-attention
从每一份我们都会得到一个(L,N,E)形状的输出,我们把这些全部concat在一起,会得到一个(L,N,E*num_heads)的张量。
这时候,我们拿一个矩阵,把这个张量的维度变回(L,N,E)即可输出。
进行forward操作
我们把我们刚才实例化好的multihead_attn拿来进行forward操作(即输入input得到output):
attn_output, attn_output_weights = multihead_attn(query, key, value)
关于mask
mask可以理解成遮罩、面具,作用是帮助我们“遮挡”掉我们不需要的东西,即让被遮挡的东西不影响我们的attention过程。
在forward的时候,有两个mask参