编码器-解码器结构的代码实现
class EncoderDecoder(nn.Module):
def __init__(self,encoder, decoder, source_embed, target_embed, generator)
super().__init__()
self.encoder=encoder
self.decoder=decoder
self.src_embed=source_embed
self.tgt_embed=target_embed
self.generator=generator
def forward(self,source,target,source_mask,target_mask):
return self.decode(self.encode(source,source_mask),source_mask,target,target_mask)
def encode(self,source,source_mask):
return self.encoder(self.src_embed(source),source_mask)
def decode(self,memory,source_mask,target,target_mask):
return self.decoder(self.tgt_embed(target),memory,source_mask,target,target_mask)