Hand-writing a complete Transformer in Python
1. Positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, x):
        # x: (seq_len, batch_size) token indices. Slice by sequence length and return
        # (seq_len, 1, d_model) so it broadcasts over the batch dimension after embedding.
        return self.encoding[:, :x.size(0)].transpose(0, 1)
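As a quick sanity check (an illustrative snippet with arbitrary sizes, assuming the (seq_len, batch) token layout used in the test of section 5), the returned encoding should broadcast against an embedded batch without changing its shape:

import torch
import torch.nn as nn

pe = PositionalEncoding(d_model=512)
tokens = torch.randint(0, 10000, (10, 32))  # (seq_len, batch)
emb = nn.Embedding(10000, 512)(tokens)      # (10, 32, 512)
print(pe(tokens).shape)                     # torch.Size([10, 1, 512])
print((emb + pe(tokens)).shape)             # torch.Size([10, 32, 512])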
2. Encoder
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(EncoderLayer, self).__init__()
        self.multihead_attention = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None):
        # Multi-head self-attention
        attn_output, _ = self.multihead_attention(src, src, src, attn_mask=src_mask)
        src = src + self.dropout(attn_output)
        src = self.layer_norm1(src)
        # Feed-forward network
        ff_output = self.feed_forward(src)
        src = src + self.dropout(ff_output)
        src = self.layer_norm2(src)
        return src
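A single encoder layer keeps the tensor shape unchanged; here is a small illustrative shape check with arbitrary sizes:

import torch

enc = EncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
x = torch.randn(10, 32, 512)  # (seq_len, batch, d_model); nn.MultiheadAttention expects seq-first by default
print(enc(x).shape)           # torch.Size([10, 32, 512])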
3. Decoder
class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(DecoderLayer, self).__init__()
        self.multihead_attention1 = nn.MultiheadAttention(d_model, nhead)
        self.multihead_attention2 = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # Masked multi-head self-attention
        attn_output1, _ = self.multihead_attention1(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = tgt + self.dropout(attn_output1)
        tgt = self.layer_norm1(tgt)
        # Multi-head cross-attention over the encoder output
        attn_output2, _ = self.multihead_attention2(tgt, memory, memory, attn_mask=memory_mask)
        tgt = tgt + self.dropout(attn_output2)
        tgt = self.layer_norm2(tgt)
        # Feed-forward network
        ff_output = self.feed_forward(tgt)
        tgt = tgt + self.dropout(ff_output)
        tgt = self.layer_norm3(tgt)
        return tgt
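A single decoder layer likewise preserves the target shape while attending to the encoder memory; a small illustrative check with arbitrary sizes:

import torch

dec = DecoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
tgt = torch.randn(20, 32, 512)     # (tgt_len, batch, d_model)
memory = torch.randn(10, 32, 512)  # encoder output: (src_len, batch, d_model)
print(dec(tgt, memory).shape)      # torch.Size([20, 32, 512])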
4. Full encoder-decoder structure
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
                 max_seq_len=512):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_decoder_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        print("src before embedding:", src.shape)
        print("tgt before embedding:", tgt.shape)
        src = self.embedding(src) + self.positional_encoding(src)
        tgt = self.embedding(tgt) + self.positional_encoding(tgt)
        print("src after positional encoding:", src.shape)
        print("tgt after positional encoding:", tgt.shape)
        # src before embedding: torch.Size([10, 32])
        # tgt before embedding: torch.Size([20, 32])
        # src after positional encoding: torch.Size([10, 32, 512])
        # tgt after positional encoding: torch.Size([20, 32, 512])
        # src_mask = self._generate_square_subsequent_mask(src.size(0)).to(src.device)
        src_mask = None
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        print("src.size(0):", src.size(0))
        print("tgt.size(0):", tgt.size(0))
        print("tgt_mask:", tgt_mask.shape)
        # tgt_mask: torch.Size([20, 20])
        memory_mask = None
        # Encode
        for layer in self.encoder_layers:
            src = layer(src, src_mask)
        print("encoder output:", src.shape)
        # Decode
        for layer in self.decoder_layers:
            tgt = layer(tgt, src, tgt_mask, memory_mask)
        # Note: every decoder layer receives the same memory, namely the output of
        # the final encoder layer (after all 6 encoder layers).
        output = self.fc_out(tgt)
        return output
Generate a square mask whose lower-left triangle (including the diagonal) is 0 and whose upper-right triangle is -inf:
    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
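For a small size such as sz = 4 the mask looks like this (0 means the position may be attended to, -inf means it is blocked); the snippet below simply repeats the same two lines outside the class as a check:

import torch

sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])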
5. Testing with a main function
vocab_size = 10000
d_model = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
batch_size = 32
sequence_length_src = 10
sequence_length_tgt = 20
model = Transformer(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)
src = torch.randint(0, vocab_size, (sequence_length_src, batch_size)) # (source sequence length, batch size)
tgt = torch.randint(0, vocab_size, (sequence_length_tgt, batch_size)) # (target sequence length, batch size)
output = model(src, tgt)
print(output.shape) # Should be (target sequence length, batch size, vocab_size)
print(src.shape)
print(tgt.shape)
The decoder input is tgt: with a batch size of 32, a target sequence length of 20, and each token mapped to a 512-dimensional vector after embedding, the decoder input has shape [20, 32, 512], i.e. [sequence_length_tgt, batch_size, d_model].
The encoder input is src, which after embedding has shape [10, 32, 512], i.e. [sequence_length_src, batch_size, d_model].
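The bookkeeping can be verified directly with a throwaway embedding layer (an illustrative snippet reusing the sizes above):

import torch
import torch.nn as nn

emb = nn.Embedding(10000, 512)
tgt_tokens = torch.randint(0, 10000, (20, 32))  # (sequence_length_tgt, batch_size)
print(emb(tgt_tokens).shape)                    # torch.Size([20, 32, 512])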
6. Complete code
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, x):
        # x: (seq_len, batch_size) token indices. Slice by sequence length and return
        # (seq_len, 1, d_model) so it broadcasts over the batch dimension after embedding.
        return self.encoding[:, :x.size(0)].transpose(0, 1)
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(EncoderLayer, self).__init__()
        self.multihead_attention = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None):
        # Multi-head self-attention
        attn_output, _ = self.multihead_attention(src, src, src, attn_mask=src_mask)
        src = src + self.dropout(attn_output)
        src = self.layer_norm1(src)
        # Feed-forward network
        ff_output = self.feed_forward(src)
        src = src + self.dropout(ff_output)
        src = self.layer_norm2(src)
        return src
class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(DecoderLayer, self).__init__()
        self.multihead_attention1 = nn.MultiheadAttention(d_model, nhead)
        self.multihead_attention2 = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # Masked multi-head self-attention
        attn_output1, _ = self.multihead_attention1(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = tgt + self.dropout(attn_output1)
        tgt = self.layer_norm1(tgt)
        # Multi-head cross-attention over the encoder output
        attn_output2, _ = self.multihead_attention2(tgt, memory, memory, attn_mask=memory_mask)
        tgt = tgt + self.dropout(attn_output2)
        tgt = self.layer_norm2(tgt)
        # Feed-forward network
        ff_output = self.feed_forward(tgt)
        tgt = tgt + self.dropout(ff_output)
        tgt = self.layer_norm3(tgt)
        return tgt
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
                 max_seq_len=512):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_decoder_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        print("src before embedding:", src.shape)
        print("tgt before embedding:", tgt.shape)
        src = self.embedding(src) + self.positional_encoding(src)
        tgt = self.embedding(tgt) + self.positional_encoding(tgt)
        print("src after positional encoding:", src.shape)
        print("tgt after positional encoding:", tgt.shape)
        # src before embedding: torch.Size([10, 32])
        # tgt before embedding: torch.Size([20, 32])
        # src after positional encoding: torch.Size([10, 32, 512])
        # tgt after positional encoding: torch.Size([20, 32, 512])
        # src_mask = self._generate_square_subsequent_mask(src.size(0)).to(src.device)
        src_mask = None
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        print("src.size(0):", src.size(0))
        print("tgt.size(0):", tgt.size(0))
        print("tgt_mask:", tgt_mask.shape)
        # tgt_mask: torch.Size([20, 20])
        memory_mask = None
        # Encode
        for layer in self.encoder_layers:
            src = layer(src, src_mask)
        print("encoder output:", src.shape)
        # Decode
        for layer in self.decoder_layers:
            tgt = layer(tgt, src, tgt_mask, memory_mask)
        # Note: every decoder layer receives the same memory, namely the output of
        # the final encoder layer (after all 6 encoder layers).
        output = self.fc_out(tgt)
        return output

    # Generate a square mask: 0 in the lower-left triangle (including the diagonal), -inf in the upper-right triangle
    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
if __name__ == "__main__":
    vocab_size = 10000
    d_model = 512
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048
    batch_size = 32
    sequence_length_src = 10
    sequence_length_tgt = 20
    model = Transformer(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)
    src = torch.randint(0, vocab_size, (sequence_length_src, batch_size))  # (source sequence length, batch size)
    tgt = torch.randint(0, vocab_size, (sequence_length_tgt, batch_size))  # (target sequence length, batch size)
    output = model(src, tgt)
    print(output.shape)  # Should be (target sequence length, batch size, vocab_size)
    print(src.shape)
    print(tgt.shape)
7. Implementation by calling the nn.Transformer module directly
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, x):
        # x: (seq_len, batch_size) token indices. Slice by sequence length and return
        # (seq_len, 1, d_model) so it broadcasts over the batch dimension after embedding.
        return self.encoding[:, :x.size(0)].transpose(0, 1)
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_len=512):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt):
        src = self.embedding(src) * math.sqrt(self.d_model) + self.positional_encoding(src)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model) + self.positional_encoding(tgt)
        src_mask = None
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        memory_mask = None
        output = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)
        output = self.fc_out(output)
        return output

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
if __name__ == "__main__":
    vocab_size = 10000
    d_model = 512
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048
    batch_size = 32
    sequence_length_src = 10
    sequence_length_tgt = 20
    model = TransformerModel(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)
    src = torch.randint(0, vocab_size, (sequence_length_src, batch_size))
    tgt = torch.randint(0, vocab_size, (sequence_length_tgt, batch_size))
    output = model(src, tgt)
    print(output.shape)
    print(src.shape)
    print(tgt.shape)