The Transformer is a deep learning model architecture proposed by Google for natural language processing (NLP) tasks, built on the self-attention mechanism. Below is a simplified PyTorch example of the multi-head attention module that sits at the core of both the Transformer encoder (Encoder) and decoder (Decoder). Note that this example is for teaching purposes only and does not cover the complete Transformer architecture (positional encoding, layer normalization, residual connections, and so on).
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    # A simplified multi-head attention implementation, for demonstration only.
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        # Linear projections for queries, keys, values, and the final output.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, depth)
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, v, k, q, mask):
        batch_size = q.shape[0]
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # Scaled dot-product attention.
        scaled_attention_logits = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.depth, dtype=torch.float32))
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)  # Add the mask to the scaled tensor.
        attention_weights = F.softmax(scaled_attention_logits, dim=-1)  # (batch_size, num_heads, seq_len_q, seq_len_k)
        output = torch.matmul(attention_weights, v)  # (batch_size, num_heads, seq_len_q, depth)
        output = output.permute(0, 2, 1, 3).contiguous()  # (batch_size, seq_len_q, num_heads, depth)
        output = output.reshape(batch_size, -1, self.d_model)  # (batch_size, seq_len_q, d_model)
        output = self.dense(output)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights

# The other Transformer components (feed-forward network, positional encoding,
# layer normalization, etc.) are omitted here. To build a complete Transformer
# encoder or decoder, you would implement those components and compose them;
# sketches follow below.
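As a companion to the omitted components mentioned in the comments above, here is a minimal sketch of the sinusoidal positional encoding from the original Transformer paper, followed by a small usage example of the MultiHeadAttention class. The max_len default and the tensor shapes are illustrative assumptions, not part of the original example.

import math

class PositionalEncoding(nn.Module):
    # Sinusoidal positional encoding: injects position information into the
    # embeddings, since attention by itself is order-agnostic.
    def __init__(self, d_model, max_len=5000):  # max_len is an assumed default
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]

# Usage example (shapes are illustrative):
batch_size, seq_len, d_model, num_heads = 2, 10, 64, 8
x = torch.rand(batch_size, seq_len, d_model)
x = PositionalEncoding(d_model)(x)
mha = MultiHeadAttention(d_model, num_heads)
out, attn = mha(x, x, x, mask=None)  # self-attention: v, k, q are all x
print(out.shape)   # torch.Size([2, 10, 64])
print(attn.shape)  # torch.Size([2, 8, 10, 10])

In the original paper the positional encoding is added to the token embeddings (which are first scaled by sqrt(d_model)) before the first encoder layer.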
Note: the MultiHeadAttention class above implements only the core of the multi-head attention mechanism; it is not a complete Transformer encoder or decoder. A complete encoder layer typically consists of a multi-head self-attention sublayer and a feed-forward network (FFN), each with a residual connection and layer normalization. A decoder layer typically contains two multi-head attention sublayers (a masked self-attention sublayer and an encoder-decoder attention sublayer) plus a feed-forward network, again with residuals and layer normalization; a sketch of both layers follows below.
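To show how these sublayers fit together, here is a hedged sketch of one encoder layer and one decoder layer built on the MultiHeadAttention class above, following the post-norm arrangement of the original paper. The dff and dropout hyperparameters and the mask arguments are illustrative assumptions.

class EncoderLayer(nn.Module):
    # One encoder layer: multi-head self-attention + FFN, each wrapped in a
    # residual connection followed by layer normalization ("post-norm").
    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _ = self.mha(x, x, x, mask)        # self-attention
        x = self.norm1(x + self.dropout(attn_out))   # residual + layer norm
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))    # residual + layer norm
        return x


class DecoderLayer(nn.Module):
    # One decoder layer: masked self-attention, encoder-decoder (cross)
    # attention, and an FFN, each with residual + layer norm.
    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_mha = MultiHeadAttention(d_model, num_heads)
        self.cross_mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        # Masked self-attention over the decoder input.
        attn1, _ = self.self_mha(x, x, x, look_ahead_mask)
        x = self.norm1(x + self.dropout(attn1))
        # Cross-attention: queries from the decoder, keys/values from the encoder output.
        attn2, _ = self.cross_mha(enc_output, enc_output, x, padding_mask)
        x = self.norm2(x + self.dropout(attn2))
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_out))
        return x

A full model would stack N such layers, add token embeddings plus positional encoding in front, and, on the decoder side, finish with a linear projection and softmax over the vocabulary.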