pytorch笔记:nn.MultiheadAttention

1 函数介绍

torch.nn.MultiheadAttention(
    embed_dim, 
    num_heads, 
    dropout=0.0, 
    bias=True, 
    add_bias_kv=False, 
    add_zero_attn=False, 
    kdim=None, 
    vdim=None, 
    batch_first=False, 
    device=None, 
    dtype=None)

2 参数介绍

embed_dim模型的维度
num_heads

attention的头数

(embed_dim会平均分配给每个头,也即每个头的维度是embed_dim//num_heads)

dropoutattn_output_weights的dropout概率
biasinput和output的投影函数,是否有bias
kdim

k的维度,默认embed_dim

vdimv的维度,默认embed_dim
batch_firstTrue——输入和输出的维度是(batch_num,seq_len,feature_dim)
False——输入和输出的维度是(batch_num,seq_len,feature_dim)

3 forward函数

forward(
    query, 
    key, 
    value, 
    key_padding_mask=None, 
    need_weights=True, 
    attn_mask=None, 
    average_attn_weights=True)

4 forward函数参数介绍

query
  • 对于没有batch的输入,维度是(length,embed_dim)
  • 对于有batch的输入,维度是(batch_num,len,embed_dim)或者(len,batch_num,embed_dim)【取决于batch_first】
key
  • 对于没有batch的输入,维度是(S_length,kdim)
  • 对于有batch的输入,维度是(batch_num,len,kdim)或者(len,batch_num,kdim)【取决于batch_first】
value
  • 对于没有batch的输入,维度是(S_length,vdim)
  • 对于有batch的输入,维度是(batch_num,len,vdim)或者(len,batch_num,vdim)【取决于batch_first】
key_padding_mask 

如果设置,那么

  • 对于没有batch的输入,这需要一个S_length大小的mask向量
  • 对于有batch的输入,这需要一个(length,S_length)大小的mask矩阵

True表示对应的key value在计算attention的时候,需要被忽略

need_weights如果设置,那么返回值会多一个attn_output_weight
attn_maskTrue表示对应的attention value 不应该存在
average_attn_weights 

如果设置,那么返回的是各个头的平均attention weight

否则,就是把所有的head分别输出

5 forward输出

attn_output
  • 对于没有batch的输入,维度为(length,embed_dim)
  • 对于有batch的输入,维度为(length,batch_size,embed_dim)或(batch_size,length,embed_dim)
attn_output_weight
  • 对于没有batch的输入
    • 如果average_attn_weights为True,那么就是(length,S_length);否则是(num_heads,length,S_length)

6 举例

import torch
import torch.nn as nn
lst=torch.Tensor([[1,2,3,4],
                [2,3,4,5],
                 [7,8,9,10]])
lst=lst.unsqueeze(1)
lst.shape
#torch.Size([3, 1, 4])


multi_atten=nn.MultiheadAttention(embed_dim=4,
                                  num_heads=2)
multi_atten(lst,lst,lst)
'''
(tensor([[[ 1.9639, -3.7282,  2.1215,  0.6630]],
 
         [[ 2.2423, -4.2444,  2.2466,  1.0711]],
 
         [[ 2.3823, -4.5058,  2.3015,  1.2964]]], grad_fn=<AddBackward0>),
 tensor([[[9.0335e-02, 1.2198e-01, 7.8769e-01],
          [2.6198e-02, 4.4854e-02, 9.2895e-01],
          [1.6031e-05, 9.4658e-05, 9.9989e-01]]], grad_fn=<DivBackward0>))
'''

### PyTorch `nn.MultiheadAttention` 类详解 #### 参数说明 `nn.MultiheadAttention` 是用于实现多头注意力机制的关键组件。该层接受多个输入张量并返回加权后的特征表示。 主要参数如下: - **embed_dim**: 表示每个位置的嵌入维度大小,即查询、键和值向量的最后一个维度。 - **num_heads**: 多头注意力中的头数。此数值决定了如何分割嵌入空间来计算不同的注意力分布[^1]。 - **dropout (optional)**: 默认为 0,在 Softmax 后应用 Dropout 层的概率。 - **bias (optional)**: 如果设为 False,则线性投影层不会带有偏置项,默认为 True。 - **add_bias_kv (optional)**: 若设置为True,则会额外添加可学习的偏差到键和值上,默认False。 - **add_zero_attn (optional)**: 是否加入全零的虚拟注意力权重,默认False。 - **kdim 和 vdim (optional)**: 键和值的特征维度;如果未指定则默认等于 embed_dim。 #### 使用实例 下面是一个简单的使用案例,展示了如何创建一个多头自注意机制,并将其应用于给定的数据集。 ```python import torch import torch.nn as nn class MultiHeadedSelfAttention(nn.Module): def __init__(self, d_model=512, n_heads=8): super(MultiHeadedSelfAttention, self).__init__() self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads) def forward(self, query, key, value): attn_output, _ = self.attn(query=query, key=key, value=value) return attn_output # 创建模型实例 d_model = 512 n_heads = 8 batch_size = 32 seq_len = 10 model = MultiHeadedSelfAttention(d_model=d_model, n_heads=n_heads) # 构建随机输入数据 query = torch.randn((seq_len, batch_size, d_model)) key = torch.randn((seq_len, batch_size, d_model)) value = torch.randn((seq_len, batch_size, d_model)) output = model.forward(query=query, key=key, value=value) print(output.shape) # 输出形状应为(seq_len, batch_size, d_model)[^2] ``` 上述代码定义了一个基于 `MultiheadAttention` 的简单自注意力网络结构,并通过构建一批假数据进行了前向传播操作验证其功能正常工作。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值