Transformer Decoder

A PyTorch implementation of the Transformer decoder: a stack of N identical layers, each combining masked self-attention over the target sequence, encoder-decoder attention over the encoder outputs, and a position-wise feed-forward sub-layer, with a residual connection and layer normalization around each sub-layer.
import copy
from torch import nn
from norm import Norm
from multi_head_attention import MultiHeadAttention
from feed_forward import FeedForward
from pos_encoder import PositionalEncoder
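# Norm, MultiHeadAttention, FeedForward, and PositionalEncoder are this
# project's own local modules (layer normalization, multi-head attention,
# the position-wise feed-forward network, and positional encoding).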
def get_clones(module, N):
    """
    Create N identical layers.

    Args:
        module: The module (layer) to be duplicated.
        N: The number of copies to create.

    Returns:
        A ModuleList containing N identical copies of the module.
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
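The deepcopy matters here: each of the N stacked layers must learn its own weights, so get_clones produces independent copies rather than N references to one module. A quick sanity check (a minimal sketch, using a plain nn.Linear in place of a decoder layer):

import copy
import torch.nn as nn

layer = nn.Linear(4, 4)
clones = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])
# Each clone owns its own weight storage; training one leaves the others untouched.
assert clones[0].weight.data_ptr() != clones[1].weight.data_ptr()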
class DecoderLayer(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, heads=8, dropout=0.1):
        super(DecoderLayer, self).__init__()
        # One LayerNorm and one Dropout per sub-layer: self-attention,
        # encoder-decoder attention, and the feed-forward network.
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)
        self.attn_1 = MultiHeadAttention(heads, d_model, dropout)  # masked self-attention
        self.attn_2 = MultiHeadAttention(heads, d_model, dropout)  # encoder-decoder attention
        self.ff = FeedForward(d_model, d_ff, dropout)
    def forward(self, x, e_outputs, src_mask=None, trg_mask=None):
        # Sub-layer 1: masked self-attention over the target sequence,
        # followed by a residual connection and layer norm (post-norm).
        attn_output_1 = self.attn_1(x, x, x, trg_mask)
        attn_output_1 = self.dropout_1(attn_output_1)
        x = x + attn_output_1
        x = self.norm_1(x)
        # Sub-layer 2: encoder-decoder attention; queries come from the
        # decoder, keys and values from the encoder outputs.
        attn_output_2 = self.attn_2(x, e_outputs, e_outputs, src_mask)
        attn_output_2 = self.dropout_2(attn_output_2)
        x = x + attn_output_2
        x = self.norm_2(x)
        # Sub-layer 3: position-wise feed-forward network.
        ff_output = self.ff(x)
        ff_output = self.dropout_3(ff_output)
        x = x + ff_output
        x = self.norm_3(x)
        return x
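A quick shape check for a single DecoderLayer (a minimal sketch; it assumes the local Norm, MultiHeadAttention, and FeedForward modules are on the import path, and uses the fact that both masks are optional in the signature above):

import torch

batch, src_len, trg_len = 2, 12, 10
layer = DecoderLayer(d_model=512)
x = torch.randn(batch, trg_len, 512)          # decoder-side input
e_outputs = torch.randn(batch, src_len, 512)  # stand-in for encoder output
out = layer(x, e_outputs)  # no masks: fine for a shape check, not for training
print(out.shape)  # expected: torch.Size([2, 10, 512])

The residual connections force every sub-layer to preserve the (batch, trg_len, d_model) shape, which is what the check above confirms.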
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size=1000, max_seq_len=50, d_model=512, d_ff=2048, N=6, heads=8, dropout=0.1):
        '''
        vocab_size: size of the target vocabulary
        max_seq_len: maximum sequence length
        d_model: embedding dimension
        d_ff: hidden dimension of the feed-forward sub-layer
        N: number of decoder layers in the stack
        heads: number of attention heads
        dropout: dropout rate
        '''
        super(TransformerDecoder, self).__init__()
        self.N = N
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoder(max_seq_len, d_model)
        self.layers = get_clones(DecoderLayer(d_model, d_ff, heads, dropout), N)
        self.norm = Norm(d_model)
    def forward(self, trg, e_outputs, src_mask, trg_mask):
        '''
        trg: decoder input (target token ids)
        e_outputs: encoder output
        src_mask: padding mask over the source sequence
        trg_mask: causal (look-ahead) mask over the target sequence
        '''
        x = self.embed(trg)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, e_outputs, src_mask, trg_mask)
        x = self.norm(x)
        return x
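An end-to-end sketch of a forward pass through the full decoder stack. The mask construction below assumes the common convention that entries equal to 0 are blocked inside MultiHeadAttention and that masks broadcast against attention scores of shape (batch, heads, trg_len, src_len); the exact shapes and semantics depend on the local MultiHeadAttention implementation, so check it before relying on this:

import torch

vocab_size, batch, src_len, trg_len, d_model = 1000, 2, 12, 10, 512
decoder = TransformerDecoder(vocab_size=vocab_size, max_seq_len=50, d_model=d_model)

trg = torch.randint(0, vocab_size, (batch, trg_len))  # target token ids
e_outputs = torch.randn(batch, src_len, d_model)      # stand-in for encoder output

# Source padding mask: all ones here, i.e. no source position is padded.
src_mask = torch.ones(batch, 1, 1, src_len)
# Target causal mask: lower-triangular, so position i sees only positions <= i.
trg_mask = torch.tril(torch.ones(trg_len, trg_len)).view(1, 1, trg_len, trg_len)

out = decoder(trg, e_outputs, src_mask, trg_mask)
print(out.shape)  # expected: torch.Size([2, 10, 512])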