超平实版Pytorch Self-Attention: 参数详解(尤其是mask)(使用nn.MultiheadAttention)

Self-Attention的结构图

本文侧重于Pytorch中对self-attention的具体实践,具体原理不作大量说明,self-attention的具体结构请参照下图。
在这里插入图片描述
(图中为输出第二项attention output的情况,k与q为key、query的缩写)

本文中将使用Pytorch的torch.nn.MultiheadAttention来实现self-attention.

forward输入中的query、key、value

首先,前三个输入是最重要的部分query、key、value。由图1可知,我们self-attention的这三样东西其实是一样的,它们的形状都是:(L,N,E) 1

L:输入sequence的长度(例如一个句子的长度)
N:批大小(例如一个批的句子个数)
E:词向量长度

forward的输出

输出的内容很少只有两项:

  1. attn_output
    即通过self-attention之后,从每一个词语位置输出来的attention。其形状为(L,N,E),是和输入的query它们形状一样的。因为毕竟只是给value乘了一个weight。

  2. attn_output_weights
    即attention weights,形状是(N,L,L),因为每一个单词和任意另一个单词之间都会产生一个weight,所以每一句句子的weight数量是L*L

实例化一个nn.MultiheadAttention

这里对MultiheadAttention进行一个实例化并传入一些参数,实例化之后我们得到的东西我们就可以向它传入input了。

实例化时的代码:

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

其中,embed_dim是每一个单词本来的词向量长度;num_heads是我们MultiheadAttention的head的数量。

pytorch的MultiheadAttention应该使用的是Narrow self-attention机制,即,把embedding分割成num_heads份,每一份分别拿来做一下attention。

也就是说:单词1的第一份、单词2的第一份、单词3的第一份…会当成一个sequence,做一次我们图1所示的self-attention。
然后,单词1的第二份、单词2的第二份、单词3的第二份…也会做一次
直到单词1的第num_heads份、单词2的第num_heads份、单词3的第num_heads份…也做完self-attention

从每一份我们都会得到一个(L,N,E)形状的输出,我们把这些全部concat在一起,会得到一个(L,N,E*num_heads)的张量。

这时候,我们拿一个矩阵,把这个张量的维度变回(L,N,E)即可输出。

进行forward操作

我们把我们刚才实例化好的multihead_attn拿来进行forward操作(即输入input得到output):

attn_output, attn_output_weights = multihead_attn(query, key, value)

关于mask

mask可以理解成遮罩、面具,作用是帮助我们“遮挡”掉我们不需要的东西,即让被遮挡的东西不影响我们的attention过程。

在forward的时候,有两个mask参

  • 105
    点赞
  • 248
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值