import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = d_model // self.num_heads
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        # Scaled dot-product attention scores: (batch, num_heads, seq_q, seq_k)
        scaled_attention = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.depth)
        if mask is not None:
            # mask must broadcast to (batch, num_heads, seq_q, seq_k); 1 marks a masked position
            scaled_attention = scaled_attention + mask * -1e9
        attention_weights = torch.softmax(scaled_attention, dim=-1)
        output = torch.matmul(attention_weights, v)
        # (batch, num_heads, seq_q, depth) -> (batch, seq_q, d_model)
        output = output.permute(0, 2, 1, 3).contiguous()
        output = output.view(batch_size, -1, self.d_model)
        return self.dense(output)
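A quick shape check (a minimal sketch; the tensor sizes are arbitrary): because the heads are split from and merged back into d_model, the layer maps a (batch, seq_len, d_model) input back to the same shape, which is what lets it be stacked inside an encoder layer.

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)      # (batch, seq_len, d_model)
out = mha(x, x, x, mask=None)    # self-attention: q, k and v are the same tensor
print(out.shape)                 # torch.Size([2, 10, 512])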
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Position-wise: the same two-layer MLP is applied to every position independently
        return self.linear2(self.relu(self.linear1(x)))
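Since the linear layers act only on the last dimension, the block transforms each position on its own; a small sketch illustrating that (the tensors are arbitrary):

ffn = FeedForward(d_model=512, d_ff=2048)
x = torch.randn(2, 10, 512)
y = ffn(x)                                                  # (2, 10, 512): shape preserved
# Each position is transformed independently of its neighbours:
print(torch.allclose(y[:, 3], ffn(x[:, 3]), atol=1e-6))     # True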
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sub-layer 1: self-attention with residual connection and layer norm (post-LN)
        attn_output = self.multi_head_attention(x, x, x, mask)
        out1 = self.layernorm1(x + self.dropout(attn_output))
        # Sub-layer 2: position-wise feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(out1)
        out2 = self.layernorm2(out1 + self.dropout(ff_output))
        return out2
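The normalization follows the original post-LN arrangement (normalize after adding the residual). A single encoder layer also keeps the (batch, seq_len, d_model) shape, so the layers stack cleanly; a quick check (sketch, arbitrary sizes):

layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048, dropout=0.1)
layer.eval()                       # disable dropout for a deterministic check
x = torch.randn(2, 10, 512)
with torch.no_grad():
    out = layer(x, mask=None)
print(out.shape)                   # torch.Size([2, 10, 512])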
class Transformer(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size,
                 target_vocab_size, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        # Precompute the sinusoidal positional encodings; register as a buffer so they
        # follow the module across devices without being treated as trainable parameters.
        self.register_buffer('pos_encoding', self.positional_encoding(max_seq_length, d_model))
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout)
                                             for _ in range(num_layers)])
        self.final_layer = nn.Linear(d_model, target_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def positional_encoding(self, max_seq_length, d_model):
        pos_encoding = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pos_encoding[:, 0::2] = torch.sin(position * div_term)  # even indices: sine
        pos_encoding[:, 1::2] = torch.cos(position * div_term)  # odd indices: cosine
        return pos_encoding.unsqueeze(0)                        # (1, max_seq_length, d_model)

    def forward(self, x, mask=None):
        seq_length = x.size(1)
        # Scale embeddings by sqrt(d_model) before adding positional information
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_length, :]
        x = self.dropout(x)
        for layer in self.encoder_layers:
            x = layer(x, mask)
        return self.final_layer(x)
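For reference, positional_encoding fills even indices with sin(pos / 10000^(2i/d_model)) and odd indices with cos of the same argument, the sinusoidal scheme from the original Transformer paper. A small sanity check of that correspondence (a sketch; the index values are arbitrary):

pe = Transformer(1, 512, 8, 2048, 1000, 1000, 100, 0.1).pos_encoding   # (1, 100, 512)
pos, i = 7, 4
expected = math.sin(pos / 10000 ** (2 * i / 512))                      # PE(pos, 2i) from the formula
print(torch.isclose(pe[0, pos, 2 * i], torch.tensor(expected), atol=1e-5))  # tensor(True)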
Usage example:

input_vocab_size = 1000
target_vocab_size = 1000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

model = Transformer(num_layers, d_model, num_heads, d_ff, input_vocab_size,
                    target_vocab_size, max_seq_length, dropout)

# Suppose the input is an integer tensor with batch_size 32 and sequence length 50
x = torch.randint(0, input_vocab_size, (32, 50))
# The mask must broadcast to (batch, num_heads, seq_q, seq_k); 1 marks a masked
# position, so an all-zeros mask means no position is masked.
mask = torch.zeros(32, 1, 50, 50)
output = model(x, mask)
print(output.shape)  # should print torch.Size([32, 50, 1000])
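In practice the mask is usually derived from padding: key positions holding the padding token are set to 1 so attention ignores them. A minimal sketch, assuming padding id 0 (the helper make_padding_mask is illustrative and not part of the code above):

def make_padding_mask(token_ids, pad_id=0):
    # (batch, seq_len) -> (batch, 1, 1, seq_len); 1 where the key position is padding
    return (token_ids == pad_id).float().unsqueeze(1).unsqueeze(2)

x = torch.randint(1, input_vocab_size, (32, 50))
x[:, 45:] = 0                          # pretend the last 5 positions are padding
output = model(x, make_padding_mask(x))
print(output.shape)                    # torch.Size([32, 50, 1000])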
This code implements the main components of a Transformer encoder: scaled dot-product attention split across multiple heads, the position-wise feed-forward network, residual connections with layer normalization, sinusoidal positional encoding, dropout, and a stack of encoder layers followed by a final projection onto the target vocabulary.