Hand-written Python code — a complete Transformer implementation

1. Positional encoding

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        # Pre-compute the sinusoidal table once: (max_len, d_model)
        encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        # Shape (max_len, 1, d_model) so it broadcasts over the batch dimension of
        # (seq_len, batch, d_model) embeddings; register_buffer keeps it on the model's device.
        self.register_buffer('encoding', encoding.unsqueeze(1))

    def forward(self, x):
        # x: (seq_len, batch) token ids; return the encodings for the first seq_len positions
        return self.encoding[:x.size(0)]
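
A quick sanity check of the module above (a minimal sketch; the vocabulary size and tensor sizes are made-up values for illustration):

import torch

# Hypothetical smoke test for the PositionalEncoding module defined above.
pe = PositionalEncoding(d_model=512, max_len=512)
tokens = torch.randint(0, 10000, (10, 32))  # (seq_len, batch) token ids, made-up values
enc = pe(tokens)                            # (seq_len, 1, d_model), broadcasts over the batch
print(enc.shape)                            # torch.Size([10, 1, 512])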

2. Encoder

class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(EncoderLayer, self).__init__()
        self.multihead_attention = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None):
        # Multi-head self-attention with residual connection and layer norm
        attn_output, _ = self.multihead_attention(src, src, src, attn_mask=src_mask)
        src = src + self.dropout(attn_output)
        src = self.layer_norm1(src)

        # Position-wise feed-forward network with residual connection and layer norm
        ff_output = self.feed_forward(src)
        src = src + self.dropout(ff_output)
        src = self.layer_norm2(src)
        return src
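
To see that the layer preserves the (seq_len, batch, d_model) layout expected by nn.MultiheadAttention (batch_first is False by default), here is a minimal sketch with made-up sizes:

import torch

# Hypothetical smoke test for the EncoderLayer defined above.
layer = EncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
x = torch.randn(10, 32, 512)  # (seq_len, batch, d_model)
out = layer(x)
print(out.shape)              # torch.Size([10, 32, 512]); shape is unchanged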

3. Decoder

class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(DecoderLayer, self).__init__()
        self.multihead_attention1 = nn.MultiheadAttention(d_model, nhead)
        self.multihead_attention2 = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # Masked multi-head self-attention (the causal mask hides future positions)
        attn_output1, _ = self.multihead_attention1(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = tgt + self.dropout(attn_output1)
        tgt = self.layer_norm1(tgt)

        # Multi-head cross-attention over the encoder output (memory)
        attn_output2, _ = self.multihead_attention2(tgt, memory, memory, attn_mask=memory_mask)
        tgt = tgt + self.dropout(attn_output2)
        tgt = self.layer_norm2(tgt)

        # Position-wise feed-forward network
        ff_output = self.feed_forward(tgt)
        tgt = tgt + self.dropout(ff_output)
        tgt = self.layer_norm3(tgt)
        return tgt
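
A minimal sketch of calling the decoder layer, assuming a 20-token target, a 10-token encoder memory, and an additive causal mask built inline for illustration:

import torch

# Hypothetical smoke test for the DecoderLayer defined above.
layer = DecoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
tgt = torch.randn(20, 32, 512)     # (tgt_len, batch, d_model)
memory = torch.randn(10, 32, 512)  # (src_len, batch, d_model), e.g. the encoder output
# Additive causal mask: -inf strictly above the diagonal, 0 elsewhere
causal_mask = torch.triu(torch.full((20, 20), float('-inf')), diagonal=1)
out = layer(tgt, memory, tgt_mask=causal_mask)
print(out.shape)                   # torch.Size([20, 32, 512])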

4. The complete encoder-decoder structure

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
                 max_seq_len=512):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_decoder_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        print("embedding前src:", src.shape)
        print("embedding前tgt:", tgt.shape)
        src = self.embedding(src) + self.positional_encoding(src)
        tgt = self.embedding(tgt) + self.positional_encoding(tgt)
        print("位置编码后src:", src.shape)
        print("位置编码后tgt:", tgt.shape)

        # embedding前src: torch.Size([10, 32])
        # embedding前tgt: torch.Size([20, 32])
        # 位置编码后src: torch.Size([10, 32, 512])
        # 位置编码后tgt: torch.Size([20, 32, 512])

        #src_mask = self._generate_square_subsequent_mask(src.size(0)).to(src.device)
        src_mask=None
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        print("src.size(0)",src.size(0))
        print("src.size(0)",tgt.size(0))

    
        print("tgt_mask:", tgt_mask.shape)

  
        # tgt_mask: torch.Size([20, 10])
        memory_mask = None

        # Encode
        for layer in self.encoder_layers:
            src = layer(src, src_mask)

        print("编码器输出:",src.shape)
        # Decode
        for layer in self.decoder_layers:
            tgt = layer(tgt, src, tgt_mask, memory_mask)
            
        # NOTE: every decoder layer attends to the output of the FINAL encoder layer
        # (passed in as memory); decoder layers are not paired with encoder layers of the same depth.
        output = self.fc_out(tgt)
        return output

    # Generate a square mask with 0 on and below the diagonal and -inf above it,
    # so each position can only attend to itself and earlier positions.
    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
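
For intuition, this is what the mask looks like for a length-4 sequence (the same construction as above, shown standalone): entries above the diagonal are -inf, so token i can only attend to tokens 0..i.

import torch

# Standalone illustration of _generate_square_subsequent_mask for sz = 4.
sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])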


5. Testing with a main function

if __name__ == "__main__":
    vocab_size = 10000
    d_model = 512
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048

    batch_size = 32
    sequence_length_src=10
    sequence_length_tgt = 20
    model = Transformer(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)

    src = torch.randint(0, vocab_size, (sequence_length_src, batch_size))  # (source sequence length, batch size)
    tgt = torch.randint(0, vocab_size, (sequence_length_tgt, batch_size))  # (target sequence length, batch size)

    output = model(src, tgt)
    print(output.shape)  # Should be (target sequence length, batch size, vocab_size)
    print(src.shape)
    print(tgt.shape)

The encoder input is src: with a batch size of 32 and a source length of 10, each token becomes a 512-dimensional vector after embedding, so the encoder input has shape [10, 32, 512], i.e. [sequence_length_src, batch_size, d_model]. The decoder input is tgt, whose sequence length is 20, so its shape after embedding is [20, 32, 512].
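
A minimal shape check that mirrors the description above (the vocabulary size and dimensions are the same hypothetical values used in the test):

import torch
import torch.nn as nn

# Hypothetical shape check matching the configuration above.
src = torch.randint(0, 10000, (10, 32))  # (src_len, batch)
tgt = torch.randint(0, 10000, (20, 32))  # (tgt_len, batch)
emb = nn.Embedding(10000, 512)
assert emb(src).shape == (10, 32, 512)   # encoder input after embedding
assert emb(tgt).shape == (20, 32, 512)   # decoder input after embedding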

6. Complete code

import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        # Pre-compute the sinusoidal table once: (max_len, d_model)
        encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        # (max_len, 1, d_model) broadcasts over the batch dimension of (seq_len, batch, d_model) inputs
        self.register_buffer('encoding', encoding.unsqueeze(1))

    def forward(self, x):
        # x: (seq_len, batch) token ids
        return self.encoding[:x.size(0)]

class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(EncoderLayer, self).__init__()
        self.multihead_attention = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None):
        # Multi-head self-attention with residual connection and layer norm
        attn_output, _ = self.multihead_attention(src, src, src, attn_mask=src_mask)
        src = src + self.dropout(attn_output)
        src = self.layer_norm1(src)

        # Position-wise feed-forward network with residual connection and layer norm
        ff_output = self.feed_forward(src)
        src = src + self.dropout(ff_output)
        src = self.layer_norm2(src)
        return src


class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super(DecoderLayer, self).__init__()
        self.multihead_attention1 = nn.MultiheadAttention(d_model, nhead)
        self.multihead_attention2 = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.layer_norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # Masked multi-head self-attention (the causal mask hides future positions)
        attn_output1, _ = self.multihead_attention1(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = tgt + self.dropout(attn_output1)
        tgt = self.layer_norm1(tgt)

        # Multi-head cross-attention over the encoder output (memory)
        attn_output2, _ = self.multihead_attention2(tgt, memory, memory, attn_mask=memory_mask)
        tgt = tgt + self.dropout(attn_output2)
        tgt = self.layer_norm2(tgt)

        # Position-wise feed-forward network
        ff_output = self.feed_forward(tgt)
        tgt = tgt + self.dropout(ff_output)
        tgt = self.layer_norm3(tgt)
        return tgt


class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
                 max_seq_len=512):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_decoder_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        print("embedding前src:", src.shape)
        print("embedding前tgt:", tgt.shape)
        src = self.embedding(src) + self.positional_encoding(src)
        tgt = self.embedding(tgt) + self.positional_encoding(tgt)
        print("位置编码后src:", src.shape)
        print("位置编码后tgt:", tgt.shape)

        # embedding前src: torch.Size([10, 32])
        # embedding前tgt: torch.Size([20, 32])
        # 位置编码后src: torch.Size([10, 32, 512])
        # 位置编码后tgt: torch.Size([20, 32, 512])

        #src_mask = self._generate_square_subsequent_mask(src.size(0)).to(src.device)
        src_mask=None
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        print("src.size(0)",src.size(0))
        print("src.size(0)",tgt.size(0))

    
        print("tgt_mask:", tgt_mask.shape)

  
        # tgt_mask: torch.Size([20, 10])
        memory_mask = None

        # Encode
        for layer in self.encoder_layers:
            src = layer(src, src_mask)

        print("编码器输出:",src.shape)
        # Decode
        for layer in self.decoder_layers:
            tgt = layer(tgt, src, tgt_mask, memory_mask)
            
        # NOTE: every decoder layer attends to the output of the FINAL encoder layer
        # (passed in as memory); decoder layers are not paired with encoder layers of the same depth.
        output = self.fc_out(tgt)
        return output

    # Generate a square mask with 0 on and below the diagonal and -inf above it,
    # so each position can only attend to itself and earlier positions.
    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask



if __name__ == "__main__":
    vocab_size = 10000
    d_model = 512
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048

    batch_size = 32
    sequence_length_src=10
    sequence_length_tgt = 20
    model = Transformer(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)

    src = torch.randint(0, vocab_size, (sequence_length_src, batch_size))  # (source sequence length, batch size)
    tgt = torch.randint(0, vocab_size, (sequence_length_tgt, batch_size))  # (target sequence length, batch size)

    output = model(src, tgt)
    print(output.shape) 
    print(src.shape)
    print(tgt.shape)

7. Implementation using PyTorch's built-in nn.Transformer module

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        # Pre-compute the sinusoidal table once: (max_len, d_model)
        encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        # (max_len, 1, d_model) broadcasts over the batch dimension of (seq_len, batch, d_model) inputs
        self.register_buffer('encoding', encoding.unsqueeze(1))

    def forward(self, x):
        # x: (seq_len, batch) token ids
        return self.encoding[:x.size(0)]


class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_len=512):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt):
        src = self.embedding(src) * math.sqrt(self.d_model) + self.positional_encoding(src)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model) + self.positional_encoding(tgt)
        
        src_mask = None
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        memory_mask = None

        output = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)
        output = self.fc_out(output)
        return output

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask



if __name__ == "__main__":
    vocab_size = 10000
    d_model = 512
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048

    batch_size = 32
    sequence_length_src = 10
    sequence_length_tgt = 20
    model = TransformerModel(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)

    src = torch.randint(0, vocab_size, (sequence_length_src, batch_size))  
    tgt = torch.randint(0, vocab_size, (sequence_length_tgt, batch_size))

    output = model(src, tgt)
    print(output.shape) 
    print(src.shape)
    print(tgt.shape)
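
Note that nn.Transformer uses the (seq_len, batch, d_model) layout by default. If batch-first tensors are preferred, recent PyTorch versions also accept a batch_first argument; a minimal sketch, assuming PyTorch 1.9 or newer:

import torch
import torch.nn as nn

# Sketch: the same core module with batch-first tensors (assumes PyTorch >= 1.9).
transformer = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6,
                             num_decoder_layers=6, dim_feedforward=2048,
                             batch_first=True)
src = torch.randn(32, 10, 512)  # (batch, src_len, d_model)
tgt = torch.randn(32, 20, 512)  # (batch, tgt_len, d_model)
out = transformer(src, tgt)
print(out.shape)                # torch.Size([32, 20, 512])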
