pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

本文将对 Scaled Dot-Product Attention,Multi-head attentionSelf-attentionTransformer等概念做一个简要介绍和区分。最后对通用的 Multi-head attention 进行代码实现和应用

一、概念:

1. Scaled Dot-Product Attention

在实际应用中,经常会用到 Attention 机制,其中最常用的是Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。

  • Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
  • Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
  • Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.

2. Multi-head attention

是在 Scaled Dot-Product Attention 的基础上,分成多个头,也就是有多个Q、K、V并行进行计算attention,可能侧重与不同的方面的相似度和权重。

3. Self-attention

自注意力机制 是在Scaled Dot-Product Attention 以及Multi-head attention的基础上的一种应用场景,就是指 QKV的来源是相同的自己和自己计算attention,类似于经过一个线性层等,输入输出等长。

如果QKV的来源是不同的,不能叫做 self-attention,只能是attention。比如GST中的KV是随机初始化的多个token,而Q是reference encoder得到的梅尔谱的一帧。同理,Q也可以是随机初始化的一个,而KV是来自于输入,这样就可以将某一变长长度为N的输入计算attention得到一个长度为1的向量。

4. Transformer

Transformer 是指 在Scaled Dot-Product Attention 以及Multi-head attention以及Self-attention的基础上的一种通用的模型框架,它包括Positional Encoding,Encoder,Decoder等等。Transformer不等于Self-attention。

二、代码实现

 平时经常会用到Attention操作,接下来对Multi-head Attention 进行代码整理和实现,方便以后可以直接调用接口,其中单头注意力机制作为其中的一种特殊情况。

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    '''
    input:
        query --- [N, T_q, query_dim] 
        key --- [N, T_k, key_dim]
        mask --- [N, T_k]
    output:
        out --- [N, T_q, num_units]
        scores -- [h, N, T_q, T_k]
    '''

    def __init__(self, query_dim, key_dim, num_units, num_heads):

        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key, mask=None):
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        ## score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)

        ## mask
        if mask is not None:
            ## mask:  [N, T_k] --> [h, N, T_q, T_k]
            mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads,1,querys.shape[2],1)
            scores = scores.masked_fill(mask, -np.inf)
        scores = F.softmax(scores, dim=3)

        ## out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out,scores

 三、实际应用:

1. 接口调用:

## 类实例化
attention = MultiHeadAttention(3,4,5,1)

## 输入
qurry = torch.randn(8, 2, 3)
key = torch.randn(8, 6 ,4)
mask = torch.tensor([[False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],])

## 输出
out, scores = attention(qurry, key, mask)
print('out:', out.shape)         ## torch.Size([8, 2, 5])
print('scores:', scores.shape)   ## torch.Size([1, 8, 2, 6])

2. mask的作用:

mask之前的 scores:

mask之后的 scores:

softmax之后的scores:

  • 39
    点赞
  • 147
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
多头注意力机制是 Transformer 模型的重要组成部分,它允许模型在不同的注意力头上学习不同的特征表示。下面是 PyTorch 应用多头注意力机制代码示例: ```python import torch import torch.nn as nn class MultiheadAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiheadAttention, self).__init__() self.num_heads = num_heads self.d_model = d_model self.d_k = d_model // num_heads self.linear_q = nn.Linear(d_model, d_model) self.linear_k = nn.Linear(d_model, d_model) self.linear_v = nn.Linear(d_model, d_model) self.linear_out = nn.Linear(d_model, d_model) def forward(self, query, key, value, mask=None): batch_size = query.size(0) # Linear projections query = self.linear_q(query) key = self.linear_k(key) value = self.linear_v(value) # Reshape tensors query = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # (batch_size, num_heads, seq_length_q, d_k) key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # (batch_size, num_heads, seq_length_k, d_k) value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # (batch_size, num_heads, seq_length_v, d_k) # Scaled dot-product attention scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32)) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # Apply mask to scores attention_weights = torch.softmax(scores, dim=-1) context = torch.matmul(attention_weights, value) # Reshape and concatenate attention heads context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k) # Final linear layer output = self.linear_out(context) return output, attention_weights ``` 上述代码定义了一个 `MultiheadAttention` 类,该类继承自 PyTorch 的 `nn.Module`。在 `forward` 方法,首先进行线性变换,然后对输入的 query、key 和 value 进行维度的调整。接下来,通过矩阵乘法计算注意力分数,如果需要的话,可以应用屏蔽(mask)操作。然后使用 softmax 函数计算注意力权重,并将其与 value 相乘得到上下文张量。最后,对上下文张量进行维度调整和线性变换得到最终输出。 你可以根据自己的需求使用这个多头注意力机制模块来构建自己的 Transformer 模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值