LLM Fundamentals: Implementing a Transformer from Scratch (5)

LLM Fundamentals: Implementing a Transformer from Scratch (1) - CSDN Blog

LLM Fundamentals: Implementing a Transformer from Scratch (2) - CSDN Blog

LLM Fundamentals: Implementing a Transformer from Scratch (3) - CSDN Blog

LLM Fundamentals: Implementing a Transformer from Scratch (4) - CSDN Blog


1. Preface

In the previous article we implemented both the Encoder and the Decoder modules.

In this part we put them together into the complete Transformer.

2. Transformer

With the Encoder and Decoder in hand, assembling the full Transformer is straightforward: import the modules we implemented earlier and stack them together.

import torch
from torch import nn, Tensor
from torch.nn import Embedding

# Import the modules implemented in previous parts of this series
from llm_base.embedding.PositionalEncoding import PositionalEmbedding
from llm_base.encoder import Encoder
from llm_base.decoder import Decoder
from llm_base.mask.target_mask import make_target_mask

class Transformer(nn.Module):
    def __init__(self,
                 source_vocab_size: int,
                 target_vocab_size: int,
                 d_model: int = 512,
                 n_heads: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 d_ff: int = 2048,
                 dropout: float = 0.1,
                 max_positions: int = 5000,
                 pad_idx: int = 0,
                 norm_first: bool = False) -> None:
        '''
        :param source_vocab_size: size of the source vocabulary.
        :param target_vocab_size: size of the target vocabulary.
        :param d_model: dimension of embeddings. Defaults to 512.
        :param n_heads: number of attention heads. Defaults to 8.
        :param num_encoder_layers: number of encoder blocks. Defaults to 6.
        :param num_decoder_layers: number of decoder blocks. Defaults to 6.
        :param d_ff: dimension of the inner feed-forward network. Defaults to 2048.
        :param dropout: dropout ratio. Defaults to 0.1.
        :param max_positions: maximum sequence length for positional encoding. Defaults to 5000.
        :param pad_idx: pad index. Defaults to 0.
        :param norm_first: if True, layer norm is applied before the attention and feed-forward
                operations (Pre-Norm); otherwise it is applied after (Post-Norm). Defaults to False.
        '''
        super().__init__()
        # Token embeddings
        self.src_embeddings = Embedding(source_vocab_size, d_model)
        self.target_embeddings = Embedding(target_vocab_size, d_model)

        # Positional embeddings
        self.encoder_pos = PositionalEmbedding(d_model, dropout, max_positions)
        self.decoder_pos = PositionalEmbedding(d_model, dropout, max_positions)

        # Encoder stack
        self.encoder = Encoder(d_model, num_encoder_layers, n_heads, d_ff, dropout, norm_first)
        # Decoder stack
        self.decoder = Decoder(d_model, num_decoder_layers, n_heads, d_ff, dropout, norm_first)

        self.pad_idx = pad_idx

    def encode(self,
               src: Tensor,
               src_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Encoding step.
        :param src: (batch_size, src_seq_length) the sequence to the encoder.
        :param src_mask: (batch_size, 1, src_seq_length) the mask for the source sequence.
        :param keep_attentions: whether to keep attention weights or not. Defaults to False.
        :return: (batch_size, src_seq_length, d_model) encoder output.
        '''
        src_embedding_tensor = self.src_embeddings(src)
        src_embedded = self.encoder_pos(src_embedding_tensor)

        return self.encoder(src_embedded, src_mask, keep_attentions)

    def decode(self,
               target_tensor: Tensor,
               memory: Tensor,
               target_mask: Tensor = None,
               memory_mask: Tensor = None,
               keep_attentions: bool = False) -> Tensor:
        '''
        Decoding step.
        :param target_tensor: (batch_size, tgt_seq_length) the sequence to the decoder.
        :param memory: (batch_size, src_seq_length, d_model) the output of the last encoder layer.
        :param target_mask: (batch_size, 1, 1, tgt_seq_length) the mask for the target sequence. Defaults to None.
        :param memory_mask: (batch_size, 1, 1, src_seq_length) the mask for the memory sequence. Defaults to None.
        :param keep_attentions: whether to keep attention weights or not. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model) decoder output.
        '''
        target_embedding_tensor = self.target_embeddings(target_tensor)
        target_embedded = self.decoder_pos(target_embedding_tensor)

        # output (batch_size, tgt_seq_length, d_model)
        output = self.decoder(target_embedded, memory, target_mask, memory_mask, keep_attentions)
        return output


    def forward(self,
                src: Tensor,
                target: Tensor,
                src_mask: Tensor = None,
                target_mask: Tensor = None,
                keep_attentions: bool = False) -> Tensor:
        '''
        :param src: (batch_size, src_seq_length) the sequence to the encoder.
        :param target: (batch_size, tgt_seq_length) the sequence to the decoder.
        :param src_mask: the mask for the source sequence. Defaults to None.
        :param target_mask: the mask for the target sequence. Defaults to None.
        :param keep_attentions: whether to keep attention weights or not. Defaults to False.
        :return: (batch_size, tgt_seq_length, d_model) decoder output.
        '''
        memory = self.encode(src, src_mask, keep_attentions)
        return self.decode(target, memory, target_mask, src_mask, keep_attentions)
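
Note that as written, forward returns the decoder hidden states of size d_model rather than vocabulary logits (the test below prints torch.Size([1, 20, 512])). In the original Transformer, a final linear layer projects these hidden states to target_vocab_size logits before the softmax. Below is a minimal sketch of how such a projection could be attached on top of the model above; the OutputProjection class and lm_head attribute are illustrative names I am assuming here, not part of the series' code.

import torch
from torch import nn, Tensor

class OutputProjection(nn.Module):
    '''Illustrative sketch: projects decoder outputs to target-vocabulary logits.'''
    def __init__(self, d_model: int, target_vocab_size: int) -> None:
        super().__init__()
        # Linear projection from the model dimension to the target vocabulary size
        self.lm_head = nn.Linear(d_model, target_vocab_size)

    def forward(self, decoder_output: Tensor) -> Tensor:
        # (batch_size, tgt_seq_length, d_model) -> (batch_size, tgt_seq_length, tgt_vocab_size)
        return self.lm_head(decoder_output)

During training, the resulting logits would typically be passed to nn.CrossEntropyLoss, which applies the softmax implicitly.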

3. Testing

Write a simple main function to check that the whole network runs end to end.

if __name__ == '__main__':
    source_vocab_size = 300
    target_vocab_size = 300
    # Index used for padding, usually 0
    pad_idx = 0

    batch_size = 1
    max_positions = 20

    model = Transformer(source_vocab_size=source_vocab_size,
                        target_vocab_size=target_vocab_size)

    src_tensor = torch.randint(0, source_vocab_size, (batch_size, max_positions))
    target_tensor = torch.randint(0, target_vocab_size, (batch_size, max_positions))

    # The last 5 positions of the source are padding
    src_tensor[:, -5:] = 0

    # The last 10 positions of the target are padding
    target_tensor[:, -10:] = 0

    src_mask = (src_tensor != pad_idx).unsqueeze(1)
    target_mask = make_target_mask(target_tensor)

    output = model(src_tensor, target_tensor, src_mask, target_mask)
    print(output.shape)
    # torch.Size([1, 20, 512])
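
For reference, make_target_mask was implemented in an earlier part of this series; a typical version combines the target padding mask with a causal (look-ahead) mask so that each position can only attend to non-padding positions at or before itself. The sketch below is an assumed reimplementation under the default pad_idx of 0, not necessarily identical to the original function.

# Assumed sketch of a target-mask helper; the real implementation lives in
# llm_base.mask.target_mask from an earlier part of this series.
import torch
from torch import Tensor

def make_target_mask_sketch(target: Tensor, pad_idx: int = 0) -> Tensor:
    # Padding mask: (batch_size, 1, 1, tgt_seq_length)
    pad_mask = (target != pad_idx).unsqueeze(1).unsqueeze(2)
    seq_length = target.size(1)
    # Causal mask: lower-triangular (tgt_seq_length, tgt_seq_length), True at or below the diagonal
    causal_mask = torch.tril(torch.ones(seq_length, seq_length, device=target.device)).bool()
    # Broadcasts to (batch_size, 1, tgt_seq_length, tgt_seq_length): attendable positions only
    return pad_mask & causal_mask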