编码器解码器架构(模板)
简单来说就是:编码器负责特征抽取,解码器负责输出,解码器也有自己的输入
编码器
from torch import nn
#@save
class Encoder(nn.Module):
"""编码器-解码器架构的基本编码器接口"""
def __init__(self, **kwargs):
super(Encoder, self).__init__(**kwargs)
def forward(self, X, *args):
raise NotImplementedError
和正常的模型一样
解码器
#@save
class Decoder(nn.Module):
"""编码器-解码器架构的基本解码器接口"""
def __init__(self, **kwargs):
super(Decoder, self).__init__(**kwargs)
def init_state(self, enc_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
init_state(self, enc_outputs, *args):
用来接收编码器的输出,转换成state
forward(self, X, state):
且在forward中解码器也有自己的输入
合并编码器和解码器
#@save
class EncoderDecoder(nn.Module):
"""编码器-解码器架构的基类"""
def __init__(self, encoder, decoder, **kwargs):
super(EncoderDecoder, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_outputs, *args)
return self.decoder(dec_X, dec_state)
forward(self, enc_X, dec_X, *args):
enc_X:编码器输入
dec_X:解码器输入
enc_outputs = self.encoder(enc_X, *args)
enc_outputs:经过编码器得到编码器输出
self.decoder.init_state(enc_outputs, *args)
通过解码器的
init_state
方法将编码器的输出变成一个状态供解码器使用
self.decoder(dec_X, dec_state)
根据刚刚得到的状态和解码器自己的输入,得到解码器最终的输出
接下来几节使用的模型都基于这个结构。