一、编码器的基本架构
forward层:传入一个x,输出一个out
from torch import nn
class Encoder(nn.Module):
def __init__(self,**kwargs):
super(Encoder, self).__init__(**kwargs)
def forward(self,x,*args):
raise NotImplementedError
二、解码器的基本架构
def init_state(self,enc_outputs,*args):利用encoder的输出建立中间状态(decoder的初始状态)
forward()内部传入encoder的输入和decoder的初始状态(随后会不断变化)
这里decoder的初始状态也就是:encoder压缩成的向量,也就是编好的码
class Decoder(nn.Module):
def __init__(self,**kwargs):
super(Decoder, self).__init__(**kwargs)
#enc_outputs是encoder的输出,初始化状态,就是用encoder的东西来转化成直接想要的状态
def init_state(self,enc_outputs,*args):
raise NotImplementedError
#可以有额外的输入,state:一开始从encoder那翻译过来,之后随着forward可以不断变化
def forward(self,x,state):
raise NotImplementedError
三、编码器-解码器架构
将上面定义的编码解码器加载进来
class EncoderDecoder(nn.Module):
def __init__(self,encoder,decoder,**kwargs):
super(EncoderDecoder, self).__init__(**kwargs)
self.encoder=encoder
self.decoder=decoder
#输入编号的码,输入编码器和解码器的输入
def forward(self,enc_x,dec_x,*args):
enc_outputs = self.encoder(enc_x,*args)
dec_state=self.decoder.init_state(enc_outputs,*args)
return self.decoder(dec_x,dec_state)