编码器-解码器

一、编码器的基本架构
forward层:传入一个x,输出一个out

from torch import nn
class Encoder(nn.Module):
    def __init__(self,**kwargs):
        super(Encoder, self).__init__(**kwargs)
    def forward(self,x,*args):
        raise NotImplementedError

二、解码器的基本架构
def init_state(self,enc_outputs,*args):利用encoder的输出建立中间状态(decoder的初始状态)
forward()内部传入encoder的输入和decoder的初始状态(随后会不断变化)
这里decoder的初始状态也就是:encoder压缩成的向量,也就是编好的码

class Decoder(nn.Module):
    def __init__(self,**kwargs):
        super(Decoder, self).__init__(**kwargs)
    #enc_outputs是encoder的输出,初始化状态,就是用encoder的东西来转化成直接想要的状态
    def init_state(self,enc_outputs,*args):
        raise NotImplementedError
    #可以有额外的输入,state:一开始从encoder那翻译过来,之后随着forward可以不断变化
    def forward(self,x,state):
        raise NotImplementedError

三、编码器-解码器架构
将上面定义的编码解码器加载进来

class EncoderDecoder(nn.Module):
    def __init__(self,encoder,decoder,**kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder=encoder
        self.decoder=decoder
    #输入编号的码,输入编码器和解码器的输入
    def forward(self,enc_x,dec_x,*args):
        enc_outputs = self.encoder(enc_x,*args)
        dec_state=self.decoder.init_state(enc_outputs,*args)
        return self.decoder(dec_x,dec_state)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值