LLM Basics: Implementing a Transformer from Scratch (1) - CSDN Blog
LLM Basics: Implementing a Transformer from Scratch (2) - CSDN Blog
LLM Basics: Implementing a Transformer from Scratch (3) - CSDN Blog
LLM Basics: Implementing a Transformer from Scratch (4) - CSDN Blog
一、Preface
The previous article implemented both the Encoder and Decoder modules. Next, we put them together into the complete Transformer.
二、Transformer
The overall Transformer architecture is shown above. We simply import the Encoder and Decoder modules we have already implemented and stack them together:
import torch
from torch import nn, Tensor
from torch.nn import Embedding

# Import the modules implemented in the previous articles
from llm_base.embedding.PositionalEncoding import PositionalEmbedding
from llm_base.encoder import Encoder
from llm_base.decoder import Decoder
from llm_base.mask.target_mask import make_target_mask
class Transformer(nn.Module):
    def __init__(self,
                 source_vocab_size: int,
                 target_vocab_size: int,
                 d_model: int = 512,
                 n_heads: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 d_ff: int = 2048,
                 dropout: float = 0.1,
                 max_positions: int = 5000,
                 pad_idx: int = 0,
                 norm_first: bool = False) -> None:
        '''
        :param source_vocab_size: size of the source vocabulary.
        :param target_vocab_size: size of the target vocabulary.
        :param d_model: dimension of embeddings. Defaults to 512.
        :param n_heads: number of heads. Defaults to 8.
        :param num_encoder_layers: number of encoder blocks. Defaults to 6.
        :param num_decoder_layers: number of decoder blocks. Defaults to 6.
        :param d_ff: dimension of inner feed-forward network. Defaults to 2048.
        :param dropout: dropout ratio. Defaults to 0.1.
        :param max_positions: maximum sequence length for positional encoding. Defaults to 5000.
        :param pad_idx: pad index. Defaults to 0.
        :param norm_first: if True, layer norm is applied before attention and feed-forward operations (Pre-Norm).
                           Otherwise it is applied after (Post-Norm). Defaults to False.
        '''
        super().__init__()
        # Token embeddings
        self.src_embeddings = Embedding(source_vocab_size, d_model)
        self.target_embeddings = Embedding(target_vocab_size, d_model)
        # Position embeddings
        self.encoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        self.decoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        # Encoder stack
        self.encoder = Encoder(d_model, num_encoder_layers, n_heads, d_ff, dropout, norm_first)
        # Decoder stack
        self.decoder = Decoder(d_model, num_decoder_layers, n_heads, d_ff, dropout, norm_first)

        self.pad_idx = pad_idx
    def encode(self,
               src: Tensor,
               src_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Encoding pass.
        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the sequence
        :param keep_attentions: whether to keep attention weights. Defaults to False.
        :return: (batch_size, src_seq_length, d_model) encoder output
        '''
        src_embedding_tensor = self.src_embeddings(src)
        src_embedded = self.encoder_pos(src_embedding_tensor)
        return self.encoder(src_embedded, src_mask, keep_attentions)
    def decode(self,
               target_tensor: Tensor,
               memory: Tensor,
               target_mask: Tensor = None,
               memory_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Decoding pass.
        :param target_tensor: (batch_size, tgt_seq_length) the sequence to the decoder.
        :param memory: (batch_size, src_seq_length, d_model) the sequence from the last layer of the encoder.
        :param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.
        :param memory_mask: (batch_size, 1, 1, src_seq_length) the mask for the memory sequence. Defaults to None.
        :param keep_attentions: whether to keep attention weights. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model) decoder output
        '''
        target_embedding_tensor = self.target_embeddings(target_tensor)
        target_embedded = self.decoder_pos(target_embedding_tensor)
        # output (batch_size, tgt_seq_length, d_model)
        output = self.decoder(target_embedded, memory, target_mask, memory_mask, keep_attentions)
        return output
    def forward(self,
                src: Tensor,
                target: Tensor,
                src_mask: Tensor = None,
                target_mask: Tensor = None,
                keep_attentions: bool = False) -> Tensor:
        '''
        :param src: (batch_size, src_seq_length) the sequence to the encoder
        :param target: (batch_size, tgt_seq_length) the sequence to the decoder
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the source sequence. Defaults to None.
        :param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.
        :param keep_attentions: whether to keep attention weights. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model) decoder output
        '''
        memory = self.encode(src, src_mask, keep_attentions)
        return self.decode(target, memory, target_mask, src_mask, keep_attentions)
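Because encode and decode are exposed as separate methods, inference can run the encoder once and then call the decoder step by step, reusing the memory. The loop below is only a minimal greedy-decoding sketch and is not part of the original series code: the proj argument (a hypothetical nn.Linear(d_model, target_vocab_size) that maps decoder states to vocabulary logits) and the bos_idx / eos_idx token ids are assumptions made here for illustration.

@torch.no_grad()
def greedy_decode(model: Transformer,
                  proj: nn.Linear,
                  src: Tensor,
                  src_mask: Tensor,
                  bos_idx: int = 1,
                  eos_idx: int = 2,
                  max_len: int = 20) -> Tensor:
    # Hypothetical greedy decoding loop; proj, bos_idx and eos_idx are assumed, not from the series code.
    memory = model.encode(src, src_mask)  # run the encoder once and reuse its output
    ys = torch.full((src.size(0), 1), bos_idx, dtype=torch.long, device=src.device)  # start with <bos>
    for _ in range(max_len - 1):
        target_mask = make_target_mask(ys)                           # causal mask over the tokens generated so far
        out = model.decode(ys, memory, target_mask, src_mask)        # (batch_size, cur_len, d_model)
        next_token = proj(out[:, -1]).argmax(dim=-1, keepdim=True)   # most likely next token id
        ys = torch.cat([ys, next_token], dim=1)
        if (next_token == eos_idx).all():                            # stop once every sequence has emitted <eos>
            break
    return ys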
三、Testing
Let's write a simple main function to check that the whole network runs end to end.
if __name__ == '__main__':
    source_vocab_size = 300
    target_vocab_size = 300
    # index used for padding, usually 0
    pad_idx = 0
    batch_size = 1
    max_positions = 20

    model = Transformer(source_vocab_size=source_vocab_size,
                        target_vocab_size=target_vocab_size)

    src_tensor = torch.randint(0, source_vocab_size, (batch_size, max_positions))
    target_tensor = torch.randint(0, target_vocab_size, (batch_size, max_positions))
    # the last 5 positions are padding
    src_tensor[:, -5:] = 0
    # the last 10 positions are padding
    target_tensor[:, -10:] = 0

    src_mask = (src_tensor != pad_idx).unsqueeze(1)
    target_mask = make_target_mask(target_tensor)

    output = model(src_tensor, target_tensor, src_mask, target_mask)
    print(output.shape)
    # torch.Size([1, 20, 512])
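Note that the printed shape ends in d_model = 512 rather than target_vocab_size = 300: as wired here, the model returns decoder hidden states without a final projection onto the target vocabulary. Below is a minimal sketch of how such a head could be appended at the end of the test above; the name lm_head is chosen for illustration and is not part of the original code.

    # Hypothetical projection head: maps decoder states to vocabulary logits.
    lm_head = nn.Linear(512, target_vocab_size)  # d_model -> target_vocab_size
    vocab_logits = lm_head(output)               # (batch_size, max_positions, target_vocab_size)
    print(vocab_logits.shape)
    # torch.Size([1, 20, 300])

During training, these vocabulary logits would feed a cross-entropy loss over the target tokens.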