Transformer Explained (6): The Decoder

Transformer Decoder

The decoder stacks N identical layers; each layer applies masked self-attention over the target sequence, cross-attention over the encoder output, and a position-wise feed-forward network, with a residual connection and layer normalization around each sublayer.

import copy
from torch import nn
from norm import Norm
from multi_head_attention import MultiHeadAttention
from feed_forward import FeedForward
from pos_encoder import PositionalEncoder


def get_clones(module, N):
    """
    Create N identical layers.

    Args:
    module: The module (layer) to be duplicated.
    N: The number of copies to create.

    Returns:
    A ModuleList containing N identical copies of the module.
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class DecoderLayer(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, heads=8, dropout=0.1):
        super(DecoderLayer, self).__init__()

        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

        self.attn_1 = MultiHeadAttention(heads, d_model, dropout)
        self.attn_2 = MultiHeadAttention(heads, d_model, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)

    def forward(self, x, e_outputs, src_mask=None, trg_mask=None):
        attn_output_1 = self.attn_1(x, x, x, trg_mask)  # masked self-attention over the decoder input
        attn_output_1 = self.dropout_1(attn_output_1)
        x = x + attn_output_1  # residual connection
        x = self.norm_1(x)  # layer normalization

        attn_output_2 = self.attn_2(x, e_outputs, e_outputs, src_mask)  # cross-attention over the encoder output
        attn_output_2 = self.dropout_2(attn_output_2)
        x = x + attn_output_2  # residual connection
        x = self.norm_2(x)  # layer normalization

        ff_output = self.ff(x)  # position-wise feed-forward layer
        ff_output = self.dropout_3(ff_output)
        x = x + ff_output  # residual connection
        x = self.norm_3(x)  # layer normalization

        return x
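
The DecoderLayer's forward pass expects a src_mask and a trg_mask, but the post does not show how these masks are built. The sketch below shows one common way to construct them, assuming the MultiHeadAttention from the earlier post accepts a boolean mask that broadcasts over the attention scores (True = attend, False = block); make_masks and pad_idx are illustrative names, not part of the original code.

import torch


def make_masks(src, trg, pad_idx=0):
    # src, trg: (batch, src_len) and (batch, trg_len) token-id tensors.
    # Source padding mask, shape (batch, 1, src_len): True where src holds a real token.
    src_mask = (src != pad_idx).unsqueeze(1)
    # Target padding mask, shape (batch, 1, trg_len).
    trg_pad = (trg != pad_idx).unsqueeze(1)
    # Causal (look-ahead) mask, shape (trg_len, trg_len): lower triangle is True,
    # so position i can only attend to positions <= i.
    trg_len = trg.size(1)
    causal = torch.tril(torch.ones(trg_len, trg_len, dtype=torch.bool, device=trg.device))
    # Combined target mask, shape (batch, trg_len, trg_len).
    trg_mask = trg_pad & causal
    return src_mask, trg_mask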


class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size=1000, max_seq_len=50, d_model=512, d_ff=2048, N=6, heads=8, dropout=0.1):
        super(TransformerDecoder, self).__init__()
        '''
        vocab_size: vocabulary size
        max_seq_len: maximum sequence length
        d_model: embedding dimension
        d_ff: hidden dimension of the feed-forward layer
        N: number of decoder layers
        heads: number of attention heads
        dropout: dropout rate
        '''

        self.N = N
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoder(max_seq_len, d_model)
        self.layers = get_clones(DecoderLayer(d_model, d_ff, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        '''
        trg: decoder input (target token ids)
        e_outputs: encoder output
        src_mask: padding mask applied in the cross-attention over e_outputs
        trg_mask: causal (look-ahead) mask applied in the decoder self-attention
        '''
        x = self.embed(trg)
        x = self.pe(x)
        print("d_x.shape", x.shape)
        print("e_outputs.shape", e_outputs.shape)
        for i in range(self.N):
            x = self.layers[i](x, e_outputs, src_mask, trg_mask)
        x = self.norm(x)
        return x
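
A quick smoke test of the full decoder, as a minimal sketch: it assumes the companion modules imported at the top (Norm, MultiHeadAttention, FeedForward, PositionalEncoder) are available from the earlier posts in this series, and it reuses the make_masks helper sketched above. The hyperparameter values are arbitrary.


if __name__ == "__main__":
    import torch

    vocab_size, d_model = 1000, 512
    batch_size, src_len, trg_len = 2, 10, 8

    decoder = TransformerDecoder(vocab_size=vocab_size, max_seq_len=50,
                                 d_model=d_model, d_ff=2048, N=6, heads=8, dropout=0.1)

    # Dummy encoder output and source/target token ids (0 is reserved for padding).
    e_outputs = torch.randn(batch_size, src_len, d_model)
    src = torch.randint(1, vocab_size, (batch_size, src_len))
    trg = torch.randint(1, vocab_size, (batch_size, trg_len))

    src_mask, trg_mask = make_masks(src, trg, pad_idx=0)
    out = decoder(trg, e_outputs, src_mask, trg_mask)
    print(out.shape)  # expected: torch.Size([2, 8, 512])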