自注意力模块

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        
        self.output_linear = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        
        # Split embeddings into heads
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        # Shape: [batch_size, num_heads, seq_len, head_dim]
        
        # Linear projections for query, key, and value
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        # Shape: [batch_size, num_heads, seq_len, head_dim]
        
        # Compute dot product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Shape: [batch_size, num_heads, seq_len, seq_len]
        
        # Apply softmax to get attention weights
        attention_weights = torch.softmax(scores, dim=-1)
        # Shape: [batch_size, num_heads, seq_len, seq_len]
        
        # Apply attention weights to value
        attention_output = torch.matmul(attention_weights, v)
        # Shape: [batch_size, num_heads, seq_len, head_dim]
        
        # Merge heads
        attention_output = attention_output.permute(0, 2, 1, 3)
        # Shape: [batch_size, seq_len, num_heads, head_dim]
        attention_output = attention_output.contiguous().view(batch_size, seq_len, embed_dim)
        # Shape: [batch_size, seq_len, embed_dim]
        
        # Apply output linear layer
        attention_output = self.output_linear(attention_output)
        # Shape: [batch_size, seq_len, embed_dim]
        
        return attention_output

这个模块接受输入张量 x,大小为 [batch_size, seq_len, embed_dim],其中 batch_size 表示批量大小,seq_len 表示序列长度,embed_dim 表示嵌入维度。模块输出大小相同的张量,表示输入的自注意力表示。

该模块执行以下步骤:

将输入张量划分为 num_heads 个头,每个头大小为 head_dim = embed_dim / num_heads。
使用线性投影将每个头的嵌入表示为查询、键和值。
对于每个头,计算注意力分数(点积注意力)并将其归一化为注意力权重。
对于每个头,将注意力权重应用于值,然后将每个头的输出合

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ermzdy2

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值