用 PyTorch 实现 multi-head attention

1. 从简单的开始: self attn的实现

# import torch.nn
import torch
from torch import nn
from math import sqrt

# torch.nn.TransformerEncoder

class Self_Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects X into Q, K, V with three ``model_dim -> model_dim`` linear layers
    and returns ``softmax(Q K^T / sqrt(model_dim)) V``.
    Input and output shape: (batch_size, seq_len, model_dim).
    """

    def __init__(self, model_dim):
        '''
        :param model_dim: feature dimension of X; also the width of Q, K and V.
        '''
        super(Self_Attention, self).__init__()
        self.model_dim = model_dim
        # Wq / Wk / Wv: model_dim -> model_dim
        self.q = nn.Linear(model_dim, model_dim)
        self.k = nn.Linear(model_dim, model_dim)
        self.v = nn.Linear(model_dim, model_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X):
        """Return the attention output for X of shape (batch_size, seq_len, model_dim)."""
        Q, K, V = self.q(X), self.k(X), self.v(X)   # each: batch_size * seq_len * model_dim
        scores = torch.bmm(Q, K.transpose(2, 1)) / sqrt(self.model_dim)  # batch_size * seq_len * seq_len
        Attn = self.softmax(scores)
        O = torch.bmm(Attn, V)                      # O: batch_size * seq_len * model_dim
        return O

    # Backward-compatible alias for the original misspelled name. nn.Module.__call__
    # dispatches to ``forward``, so the misspelling made the module uncallable.
    forwad = forward

2. multihead attention

import torch
from torch import nn
from math import sqrt
import numpy as np
import copy

class dot_attention(nn.Module):
    """Scaled dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d)) V`` with ``d = Q.shape[-1]``, optionally
    masking positions and applying dropout to the attention weights.
    """

    def __init__(self, attention_dropout=0.0):      # dropout on the attention weights; usually left at 0
        super(dot_attention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, Q, K, V, attn_mask=None):
        """
        :param Q: 3-D tensor (batch, len_q, d)  -- torch.bmm requires 3-D inputs
        :param K: 3-D tensor (batch, len_k, d)
        :param V: 3-D tensor (batch, len_k, d_v)
        :param attn_mask: optional bool tensor broadcastable to (batch, len_q, len_k);
            True marks positions that must NOT be attended to (filled with -inf).
        :return: (Attn, O) -- weights (batch, len_q, len_k) and output (batch, len_q, d_v).
        """
        scale = Q.shape[-1]
        Attn = torch.bmm(Q, K.transpose(2, 1)) / sqrt(scale)
        # Compare against None explicitly: truth-testing a multi-element tensor
        # raises "Boolean value of Tensor with more than one value is ambiguous".
        if attn_mask is not None:
            Attn = Attn.masked_fill(attn_mask, -np.inf)
        Attn = self.softmax(Attn)
        Attn = self.dropout(Attn)
        O = torch.bmm(Attn, V)
        return Attn, O

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with residual connection and LayerNorm.

    Computes ``LayerNorm(X + Dropout(linear_final(concat_h head_h)))``. Each
    ``head_h`` is scaled dot-product attention (via ``dot_attention``) over its own
    ``dim_per_head``-wide slice of the projected Q / K / V.
    """

    def __init__(self, model_dim, num_heads, dropout=0.0):
        """
        :param model_dim: feature dimension of X; must be divisible by num_heads.
        :param num_heads: number of attention heads.
        :param dropout: dropout rate, used both on the attention weights
            (inside dot_attention) and on the projected output.
        :raises ValueError: if model_dim is not divisible by num_heads.
        """
        # nn.Module.__init__ must run before any sub-module is assigned to self.
        super(MultiHeadAttention, self).__init__()
        if model_dim % num_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim_per_head = model_dim // num_heads
        self.num_heads = num_heads
        self.model_dim = model_dim

        self.k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.v = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.dot_attn = dot_attention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(model_dim)

    def forward(self, X, attn_mask=None):
        """
        :param X: (batch_size, seq_len, model_dim)
        :param attn_mask: optional bool mask, True = do not attend. Either
            (batch_size, seq_len, seq_len), copied once per head, or
            (seq_len, seq_len), broadcast over batch and heads by masked_fill.
        :return: (batch_size, seq_len, model_dim)
        """
        residual = X
        batch_size = X.shape[0]
        if attn_mask is not None and attn_mask.dim() == 3:
            # Row b * num_heads + h of the split tensors is sample b, head h
            # (see _split_heads), so each sample's mask is repeated in place.
            attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)

        K, Q, V = self.k(X), self.q(X), self.v(X)
        Q, K, V = self._split_heads(Q), self._split_heads(K), self._split_heads(V)

        _attn, O = self.dot_attn(Q, K, V, attn_mask)
        O = self._merge_heads(O, batch_size)
        O = self.linear_final(O)
        O = self.dropout(O)
        O = self.norm(O + residual)
        return O

    def _split_heads(self, x):
        """(batch, seq, num_heads*dim_per_head) -> (batch*num_heads, seq, dim_per_head).

        Each head receives its own slice of the feature axis. A plain
        ``view(batch*num_heads, -1, dim_per_head)`` would instead chop the flat
        memory along the sequence axis and mix positions across heads.
        """
        batch_size, seq_len, _ = x.shape
        x = x.view(batch_size, seq_len, self.num_heads, self.dim_per_head).transpose(1, 2)
        return x.reshape(batch_size * self.num_heads, seq_len, self.dim_per_head)

    def _merge_heads(self, x, batch_size):
        """Inverse of _split_heads: (batch*num_heads, seq, dim_per_head) -> (batch, seq, model_dim)."""
        x = x.view(batch_size, self.num_heads, -1, self.dim_per_head).transpose(1, 2)
        return x.reshape(batch_size, -1, self.model_dim)




if __name__ == '__main__':
    # Demo: 5 query vectors attend over 10 key/value vectors, all 512 wide.
    query = torch.rand(1, 5, 512)
    key = torch.rand(1, 10, 512)
    value = copy.deepcopy(key)      # values are an independent copy of the keys
    attention = dot_attention()
    weights, output = attention(query, key, value)
    print("Attn:", weights)
    print("O:", output)

3. 无 attn_mask 无 dropout 的简化版

class MultiHeadAttention(nn.Module):
    """Simplified multi-head self-attention: no attention mask, no dropout.

    Computes ``fc(concat_h softmax(Q_h K_h^T / sqrt(dim_per_head)) V_h)``.
    Input and output shape: (batch_size, seq_len, model_dim).
    """

    def __init__(self, model_dim, head_nums):
        """
        :param model_dim: feature dimension of X; must be divisible by head_nums.
        :param head_nums: number of attention heads.
        :raises ValueError: if model_dim is not divisible by head_nums.
        """
        super().__init__()  # nn.Module.__init__ accepts no extra arguments
        if model_dim % head_nums != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by head_nums ({head_nums})"
            )
        self.model_dim = model_dim
        self.head_nums = head_nums
        # Integer division: used as a size in view(), which rejects floats.
        self.dim_per_head = model_dim // head_nums
        self.linear_k = nn.Linear(model_dim, model_dim)
        self.linear_q = nn.Linear(model_dim, model_dim)
        self.linear_v = nn.Linear(model_dim, model_dim)
        self.softmax = nn.Softmax(-1)
        self.fc = nn.Linear(model_dim, model_dim)

    def forward(self, X):                   # X: batch_size * seq_len * dim
        """
        :param X: (batch_size, seq_len, model_dim)
        :return: (batch_size, seq_len, model_dim)
        """
        bsz, seq_len = X.shape[0], X.shape[1]
        K, Q, V = self.linear_k(X), self.linear_q(X), self.linear_v(X)
        K, Q, V = (self._split_heads(t, bsz, seq_len) for t in (K, Q, V))
        Attn = self.softmax(torch.bmm(Q, K.transpose(2, 1)) / sqrt(self.dim_per_head))
        O = torch.bmm(Attn, V)              # (bsz*head_nums, seq_len, dim_per_head)
        # Undo the head split: (bsz, heads, seq, d_h) -> (bsz, seq, heads*d_h).
        O = O.view(bsz, self.head_nums, seq_len, self.dim_per_head).transpose(1, 2)
        O = O.reshape(bsz, seq_len, self.model_dim)
        O = self.fc(O)
        return O

    def _split_heads(self, x, bsz, seq_len):
        """(bsz, seq_len, model_dim) -> (bsz*head_nums, seq_len, dim_per_head), one feature slice per head."""
        x = x.view(bsz, seq_len, self.head_nums, self.dim_per_head).transpose(1, 2)
        return x.reshape(bsz * self.head_nums, seq_len, self.dim_per_head)

4. 附：文心一言生成的 multi-head attention 代码

import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention.

    Input and output shape: (batch_size, seq_len, d_model). The softmax weights
    from the most recent forward pass are stored in ``self.attention_weights``
    with shape (batch_size, num_heads, seq_len, seq_len).
    """

    def __init__(self, d_model, num_heads):
        """
        :param d_model: feature dimension of the input; must be divisible by num_heads.
        :param num_heads: number of attention heads.
        """
        super(MultiHeadSelfAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)

        # Softmax weights of the last forward() call, kept for inspection.
        self.attention_weights = None

    def split_heads(self, x, batch_size):
        """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)."""
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.transpose(2, 1)

    def combine_heads(self, x):
        """Inverse of split_heads: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, d_model).

        The sequence axis must be preserved. Flattening it into the feature axis
        (``view(batch, -1)``) made ``fc_out`` fail for any seq_len > 1.
        """
        batch_size = x.size(0)
        return x.transpose(2, 1).contiguous().view(batch_size, -1, self.d_model)

    def forward(self, inputs):
        """
        :param inputs: (batch_size, seq_len, d_model)
        :return: (batch_size, seq_len, d_model)
        """
        batch_size = inputs.size(0)
        q = self.split_heads(self.query_linear(inputs), batch_size)
        k = self.split_heads(self.key_linear(inputs), batch_size)
        v = self.split_heads(self.value_linear(inputs), batch_size)

        # Scale by sqrt(head_dim); a Python float avoids building a CPU tensor
        # (and a new Softmax module) on every call.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        self.attention_weights = torch.softmax(scores, dim=-1)

        attended = torch.matmul(self.attention_weights, v)
        attended = self.combine_heads(attended)

        output = self.fc_out(attended)
        return output

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值