Code-带掩码的多头自注意力

import torch.nn as nn
import torch
import math

class Mask_multi_head_self_attention(nn.Module):
  def __init__(self, n_heads, d_model):
    super().__init__()
    self.n_heads = n_heads
    self.d_model = d_model
    # 映射到Q K V 空间
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    # 因为多头,所以最后需要一个映射
    self.w_combine = nn.Linear(d_model, d_model)
    self.softmax = nn.Softmax(dim=-1)
  
  def forward(self, x):
    b,seq_len,dim = x.shape
    head_dim = self.d_model // self.n_heads
    q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
    q = q.view(b, seq_len, self.n_heads, head_dim).permute(0,2,1,3)
    k = k.view(b, seq_len, self.n_heads, head_dim).permute(0,2,1,3)
    v = v.view(b, seq_len, self.n_heads, head_dim).permute(0,2,1,3)

    score = (q @ k.transpose(-1,-2)) / math.sqrt(head_dim)
    mask = torch.tril(torch.ones(seq_len, seq_len)) # 下三角矩阵
    mask = torch.where(mask==0, float('-inf'), 0)
    score = self.softmax(score+mask) @ v
    score = score.permute(0,2,1,3).contiguous().view(b, seq_len, -1)

    out = self.w_combine(score)
    return out

d_model=512
n_head=8
x=torch.rand(5, 100, 512) # b x seq_len x dim
model = Mask_multi_head_self_attention(n_head, d_model)
out = model(x)
print(out.shape)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值