import torch
import torch.nn as nn
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Separate projections for query, key, value, and the final output
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.output_linear = nn.Linear(embed_dim, embed_dim)
    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        # Linear projections for query, key, and value,
        # applied to the full embedding before splitting into heads
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        # Shape: [batch_size, seq_len, embed_dim]
        # Split each projection into heads
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Shape: [batch_size, num_heads, seq_len, head_dim]
        # Compute scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Shape: [batch_size, num_heads, seq_len, seq_len]
        # Apply softmax to get attention weights
        attention_weights = torch.softmax(scores, dim=-1)
        # Shape: [batch_size, num_heads, seq_len, seq_len]
        # Apply attention weights to the values
        attention_output = torch.matmul(attention_weights, v)
        # Shape: [batch_size, num_heads, seq_len, head_dim]
        # Merge the heads back together
        attention_output = attention_output.permute(0, 2, 1, 3)
        # Shape: [batch_size, seq_len, num_heads, head_dim]
        attention_output = attention_output.contiguous().view(batch_size, seq_len, embed_dim)
        # Shape: [batch_size, seq_len, embed_dim]
        # Apply the output linear layer
        attention_output = self.output_linear(attention_output)
        # Shape: [batch_size, seq_len, embed_dim]
        return attention_output
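
For reference, the score and softmax steps above implement standard scaled dot-product attention; per head, the computation is

Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V

where d_k = head_dim. Scaling by √d_k keeps the dot products from growing with the head dimension, which would otherwise push the softmax into regions with very small gradients.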
This module takes an input tensor x of shape [batch_size, seq_len, embed_dim], where batch_size is the batch size, seq_len is the sequence length, and embed_dim is the embedding dimension. It returns a tensor of the same shape containing the self-attention representation of the input.
The module performs the following steps:
1. Projects the input into queries, keys, and values with linear layers, then splits each projection into num_heads heads of size head_dim = embed_dim / num_heads.
2. For each head, computes dot-product attention scores, scales them by √head_dim, and normalizes them with softmax to obtain attention weights.
3. For each head, applies the attention weights to the values, then merges the per-head outputs back into a single [batch_size, seq_len, embed_dim] tensor and passes it through the output linear layer.
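
A minimal sanity check for the module (the dimensions below are arbitrary choices for illustration):

model = SelfAttention(embed_dim=64, num_heads=8)
x = torch.randn(2, 10, 64)   # [batch_size=2, seq_len=10, embed_dim=64]
out = model(x)
print(out.shape)             # torch.Size([2, 10, 64]), same shape as the input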