手撕多头自注意力MultiHeadAttention

一、目录

  1. 代码

二、实现

  1. 代码
import torch
import torch.nn as nn
import numpy as np


'''手撕多头自注意力'''

class MultiHeadAttention(nn.Module):
    def __init__(self,input_dim,heads,d_model,dropout=0.1):
        super(MultiHeadAttention,self).__init__()
        self.d_model=d_model
        self.heads=heads
        self.input_dim=input_dim
        self.d_k=d_model//heads

        self.linear_q=nn.Linear(self.input_dim,self.d_model)
        self.linear_k=nn.Linear(self.input_dim,self.d_model)
        self.linear_v=nn.Linear(self.input_dim,self.d_model)

        self.dropout=nn.Dropout(dropout)
        self.fact=1/np.sqrt(self.d_k)

        self.out=nn.Linear(d_model,d_model)


    def forward(self,x,mask=None):
        batch_size,seq_len,hidden_size=x.shape
        q=self.linear_q(x).view(batch_size,-1,self.heads,self.d_k)
        k=self.linear_k(x).view(batch_size,-1,self.heads,self.d_k)
        v=self.linear_v(x).view(batch_size,-1,self.heads,self.d_k)

        q=q.transpose(1,2)
        k=k.transpose(1,2)
        v=v.transpose(1,2)   #batch,head,seq_len,d_k

        #求注意力
        score=torch.matmul(q,k.transpose(-2,-1))*self.fact
        if mask is not None:
            score=score+mask

        att=torch.softmax(score,dim=-1)
        if self.dropout is not None:
            att=self.dropout(att)
        output=torch.matmul(att,v)    #(batch,head,sel_len,d_k)

        #拼接
        concat=output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)

        output=self.out(concat)
        return output





if __name__ == '__main__':
    pass
    batch=2
    seq_len=5
    input_dim=32
    head=2
    d_model=32
    x=torch.randn(size=(batch,seq_len,input_dim))
    attention=MultiHeadAttention(input_dim,head,d_model)
    print(attention(x).shape)
    #求掩码
    attention_mask=torch.tril(torch.ones(size=(seq_len,seq_len),dtype=torch.bool)).view(1, 1, seq_len, seq_len)
    attention_mask=attention_mask.to(dtype=torch.float16)
    attention_mask= (1.0-attention_mask)*torch.finfo(torch.float16).min
    print(attention_mask)
    print(attention(x).shape)






  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值