Transformer模型解码器部分实现

说明:部分内容来自于网络教程,如有侵权请联系本人删除

教程链接:2.4.2解码器-part2_哔哩哔哩_bilibili

1.解码器层的作用

作为解码器的组成单元,每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程。

代码实现:

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
    '''
        size: 词嵌入维度大小,也代表解码器的尺寸
        self_attn: 多头自注意力对象,也就是说这个注意力机制需要 Q=K=V
        src_attn: 多头注意力对象,这里 Q!=K=V
        feed_forward: 前馈全连接层
    '''
        super(DecoderLayer, self).__init__()
    
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self。dropout = dropout
        self.sublayer = clones(SublayerConnection(size, dropout),3)

    def forward(self, x, memory, source_mask, target_mask):
    '''
        x: 上一层的输入
        memory: 来自编码器的语义存储变量memory
        source_mask: 源数据掩码张量
        target_mask: 目标数据掩码张量
    '''
        m = memory
        # 使用target_mask,为了将编码时未来的信息遮掩
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, target_mask))
        # 使用source_mask,为了遮掩掉对结果信息无用的数据
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, source_mask))
        # 最终输出由编码器输入和目标数据一同作用的特征提取结果
        return self.sublayer[2](x, self.feed_forward)
    
    

2.解码器作用

根据编码器的结果以及上一次预测的结果,对下一次可能出现的值进行特征表示

代码实现:(实际上就是解码器层的堆叠)

class Decoder(nn.Module):    
    def __init__(self, Layer, N):
        
        super(Decoder, self).__init__()
        
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer, size)

    def forward(self, x, memory, source_mask, target_mask):
        
        for layer in self.layers:
            x = layer(x, memory, source_mask, target_mask)
        # 输出解码过程的最终表示 
        return self.norm(x)
    

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

APPLECHARLOTTE

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值