官方文档链接:MultiheadAttention — PyTorch 1.12 documentation
目录
多注意头原理
MultiheadAttention,翻译成中文即为多注意力头,是由多个单注意头拼接成的
它们的样子分别为:👇
单头注意力的图示如下:

整体称为一个单注意力头,因为运算结束后只对每个输入产生一个输出结果,一般在网络中,输出可以被称为网络提取的特征,那我们肯定希望提取多种特征,[ 比如说我输入是一个修狗狗图片的向量序列,我肯定希望网络提取到特征有形状、颜色、纹理等等,所以单次注意肯定是不够的 ]
于是最简单的思路,最优雅的方式就是将多个头横向拼接在一起,每次运算我同时提到多个特征,所以多头的样子如下:

其中的紫色长方块(Scaled Dot-Product Attention)就是上一张单注意力头,内部结构没有画出,如果拼接h个单注意力头,摆放位置就如图所示。
因为是拼接而成的,所以每个单注意力头其实是各自输出各自的,所以会得到h个特征,把h个特征拼接起来,就成为了多注意力的输出特征。
pytorch的多注意头
首先可以看出我们调用的时候,只要写torch.nn.MultiheadAttention就好了,比如👇
import torch
import torch.nn as n
# 先决定参数
dims = 256 * 10 # 所有头总共需要的输入维度
heads = 10 # 单注意力头的总共个数
dropout_pro = 0.0 # 单注意力头
# 传入参数得到我们需要的多注意力头
layer = torch.nn.MultiheadAttention(embed_dim = dims, num_heads = heads, dropout = dropout_pro)
解读 官方给的参数解释:
embed_dim - Total dimension of the model 模型的总维度(总输入维度)
所以这里应该输入的是每个头输入的维度×头的数量
num_heads - Number of parallel attention heads. Note that embed_dim
will be split across num_heads
(i.e. each head will have dimension embed_dim // num_heads
).
num_heads即为注意头的总数量
注意看括号里的这句话,每个头的维度为 embed_dim除num_heads
也就是说,如果我的词向量的维度为n,(注意不是序列的维度),我准备用m个头提取序列的特征,则embed_dim这里的值应该是n×m,num_heads的值为m。
【更