Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析

官方文档链接:MultiheadAttention — PyTorch 1.12 documentation

目录

多注意头原理

pytorch的多注意头

解读 官方给的参数解释:

多注意头的pytorch使用

完整的使用代码


多注意头原理

MultiheadAttention,翻译成中文即为多注意力头,是由多个单注意头拼接成的

它们的样子分别为:👇

        单头注意力的图示如下:

单注意力头
​​ 

        整体称为一个单注意力头,因为运算结束后只对每个输入产生一个输出结果,一般在网络中,输出可以被称为网络提取的特征,那我们肯定希望提取多种特征,[ 比如说我输入是一个修狗狗图片的向量序列,我肯定希望网络提取到特征有形状、颜色、纹理等等,所以单次注意肯定是不够的 ]

        于是最简单的思路,最优雅的方式就是将多个头横向拼接在一起,每次运算我同时提到多个特征,所以多头的样子如下:

多注意力头
多注意力头

        其中的紫色长方块(Scaled Dot-Product Attention)就是上一张单注意力头,内部结构没有画出,如果拼接h个单注意力头,摆放位置就如图所示。

        因为是拼接而成的,所以每个单注意力头其实是各自输出各自的,所以会得到h个特征,把h个特征拼接起来,就成为了多注意力的输出特征。


pytorch的多注意头

        

首先可以看出我们调用的时候,只要写torch.nn.MultiheadAttention就好了,比如👇

import torch
import torch.nn as n

# 先决定参数
dims = 256 * 10 # 所有头总共需要的输入维度
heads = 10    # 单注意力头的总共个数
dropout_pro = 0.0 # 单注意力头

# 传入参数得到我们需要的多注意力头
layer = torch.nn.MultiheadAttention(embed_dim = dims, num_heads = heads, dropout = dropout_pro)

解读 官方给的参数解释:

embed_dim - Total dimension of the model 模型的总维度(总输入维度)

        所以这里应该输入的是每个头输入的维度×头的数量

num_heads - Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

        num_heads即为注意头的总数量        

        注意看括号里的这句话,每个头的维度为 embed_dim除num_heads

        也就是说,如果我的词向量的维度为n,(注意不是序列的维度),我准备用m个头提取序列的特征,则embed_dim这里的值应该是n×m,num_heads的值为m。

【更

评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值