Transformer解码器原理解析

Transformer解码器原理
在这里插入图片描述
解码器层

import torch
import torch.nn as nn
class DecoderLayer(nn.Module):
        def __init__(self,size,self_attn,src_attn,dropout):
                super().__init__()
                self.size=size
                self.self_attn=self_attn
                self.src_attn=src_attn
                self.feed_forward=feed_forward
                self.sublayer=clones(SubLayerconnection(),3)

        forward(input,memory,source_mask,target_mask):
                m=memory
                input=self.sublayer[0](input,lambda input:self.self_attn(input,input,input,target_mask)
                input=self.sublayer[1](input,lambda input:self.src_attn(input,m,m,source_target)
                return self.sublayer[2](input,self.feed_forward)


dl = DecoderLayer(size, self_attn, src_attn, ff, dropout)
dl_result = dl(x, memory, source_mask, target_mask)



解码器

class Decoder(nn.Module):
        def __init__(self,layer,N):
                super().__init__()
                self.layers=clones(layer,N)
                self.norm=NormLayer(layer.size)

        def forward(self,input,memory,source_mask,target_mask):
                for layer in self.layers:

                        input=layer(input,memory,source_mask,target_mask)
                return self.norm(input)

c=copy.deepcopy
attn=MultiHeadedAttention(head,d_model)
feed_forward=PositionalwisefeedForward(d_model,d_ff,dropout)

layer=DecoderLayer(size,c(attn),c(attn),c(feed_forward),dropout)
de=Decoder(layer,N)

output_de=de(input,memory,source_mask,target_mask)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值