Paper: Attention Is All You Need
In the previous post, I outlined the theory behind the encoder and implemented it in PyTorch.
This post picks up where that one left off and covers the implementation of the decoder in the Transformer model.
Most layers are actually shared between the encoder and the decoder, e.g. the normalization layer, multi-head attention, and the feed-forward layer.
The only structural difference is that each decoder layer has one extra attention sublayer, a (masked) multi-head self-attention, in its stack of sublayer connections.
For the shared material, please refer to the previous post.
Text Embedding Layer
The text embedding layer turns the numeric (index) representation of each token into a high-dimensional vector, so that relationships between words can be captured in that high-dimensional space.
See the previous post for the full implementation; a minimal sketch follows.
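This sketch assumes the module names shown in the model printout at the end of this post (TextEmbedding, lut); the sqrt(d_model) scaling follows the original paper and may differ in detail from the previous post's version.

import math
import torch.nn as nn

class TextEmbedding(nn.Module):
    def __init__(self, d_model, vocab):
        '''
        :param d_model: embedding dimension
        :param vocab: vocabulary size
        '''
        super(TextEmbedding, self).__init__()
        self.lut = nn.Embedding(vocab, d_model, padding_idx=0)
        self.d_model = d_model

    def forward(self, x):
        # scale the embeddings by sqrt(d_model), as in the original paper
        return self.lut(x) * math.sqrt(self.d_model)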
Positional Encoding Layer
After the text embedding layer, the data flows into the positional encoding layer.
The biggest difference between the Transformer and an LSTM is that LSTM training is iterative and serial: the current token must be fully processed before the next one can be handled. Transformer training is parallel: all tokens are processed at the same time, which greatly improves computational efficiency. To understand word order, the Transformer uses positional embeddings (Positional Encoding), and performs its computation with the self-attention mechanism (Self-Attention) and fully connected layers.
Nothing else in the Transformer architecture accounts for token positions, so a positional encoder must be added right after the Embedding.
It injects position-dependent information into the word-embedding tensor (the same words in different positions can carry different meanings), compensating for the missing order information.
See the previous post for the full implementation; a sinusoidal sketch follows.
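A minimal sketch of sinusoidal positional encoding, assuming the standard formulation from the paper (the class in the printout below appears as PositionalEnconding with Dropout(p=0.2); the sketch mirrors those settings):

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.2, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dpot = nn.Dropout(p=dropout)
        # precompute the sinusoidal position table once
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # add the fixed positional encodings to the embedded tokens
        x = x + self.pe[:, :x.size(1)]
        return self.dpot(x)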
Multi-Head Self-Attention
The figure above shows the Transformer architecture; the right half is the decoder.
We can see that the decoder is a stack of N decoder layers,
and each decoder layer consists of three sublayer-connection structures.
The first sublayer connection contains a multi-head self-attention sublayer, a normalization layer, and a residual connection.
The second sublayer connection contains a multi-head attention sublayer, a normalization layer, and a residual connection.
The third sublayer connection contains a feed-forward sublayer, a normalization layer, and a residual connection.
As the figure shows, when Q = K = V, multi-head attention becomes multi-head self-attention.
So at the code level, as long as the same tensor is passed for the Q, K, and V arguments, multi-head attention turns into multi-head self-attention.
See the previous post for the full implementation; a quick demonstration follows.
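As an illustration, PyTorch's built-in nn.MultiheadAttention stands in here for the custom MultiHeadAtten class from the previous post (batch_first requires PyTorch 1.9 or later); the point is only that passing one tensor for all three arguments yields self-attention.

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out, _ = mha(x, x, x)         # Q = K = V = x  ->  self-attention
print(out.shape)              # torch.Size([2, 10, 512])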
Multi-Head Attention
"Multi-head" means that a set of linear layers first transforms Q, K, and V separately.
These transforms do not change the tensor's shape, so each transformation matrix is square (d_model x d_model).
Each head then takes its own slice of Q, K, and V, splitting the output tensor along the semantic dimension:
each word's representation is divided among the heads, i.e. only the last dimension (the word-embedding vector) is split.
Feeding each head's slice through the attention mechanism and combining the results forms multi-head attention.
See the previous post for the full implementation; the split itself looks like this.
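A minimal sketch of that split, assuming head = 8 and d_model = 512 (the values in the printout below):

import torch

batch, seq_len, d_model, head = 2, 10, 512, 8
d_k = d_model // head                      # 64 dimensions per head

q = torch.randn(batch, seq_len, d_model)   # output of one square Linear layer
q = q.view(batch, seq_len, head, d_k).transpose(1, 2)
print(q.shape)                             # torch.Size([2, 8, 10, 64])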
Normalization Layer
As the number of layers grows, the values flowing through the network may become too large or too small after many layers of computation.
This can destabilize learning and slow the model's convergence.
A normalization layer is therefore added to keep the feature values within a reasonable range.
See the previous post for the full implementation; a minimal sketch follows.
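A minimal sketch of the custom LayerNorm shown in the printout below; the parameter names (a, b) and the eps value are assumptions, not necessarily those of the previous post:

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a = nn.Parameter(torch.ones(features))   # learnable scale
        self.b = nn.Parameter(torch.zeros(features))  # learnable shift
        self.eps = eps

    def forward(self, x):
        # normalize each position's feature vector to zero mean, unit std
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a * (x - mean) / (std + self.eps) + self.b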
Feed-Forward Layer
The feed-forward layer is simply two fully connected layers.
Its purpose: the attention mechanism alone may not fit complex functions well enough, so two extra linear layers are added to strengthen the model's capacity.
See the previous post for the full implementation; a minimal sketch follows.
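A sketch matching the printed structure at the end of this post (w1: 512 -> 64, w2: 64 -> 512, Dropout(p=0.2)); the ReLU between the two layers is an assumption, consistent with the standard implementation (the paper itself uses d_ff = 2048):

import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=64, dropout=0.2):
        super(FeedForward, self).__init__()
        self.w1 = nn.Linear(d_model, d_ff)   # 512 -> 64, per the printout below
        self.w2 = nn.Linear(d_ff, d_model)   # 64 -> 512
        self.dpot = nn.Dropout(p=dropout)

    def forward(self, x):
        # linear -> ReLU -> dropout -> linear
        return self.w2(self.dpot(F.relu(self.w1(x))))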
Output Layer
The Transformer's output layer is just a fully connected layer followed by a softmax.
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, d_model, vocab):
        '''
        :param d_model: embedding dimension
        :param vocab: vocabulary size
        '''
        super(Generator, self).__init__()
        self.project = nn.Linear(d_model, vocab)

    def forward(self, x):
        # log_softmax is used because of how the loss function is implemented
        # in this version of PyTorch (it expects log-probabilities).
        # log_softmax just takes the log of the softmax output; since log is
        # monotonically increasing, the relative ordering is unchanged.
        return F.log_softmax(self.project(x), dim=-1)
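Used on a batch of decoder outputs, the shapes work out like this (the sizes here are hypothetical):

import torch

gen = Generator(d_model=512, vocab=500)
x = torch.randn(2, 10, 512)    # (batch, seq_len, d_model)
log_probs = gen(x)
print(log_probs.shape)         # torch.Size([2, 10, 500])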
Implementing the Transformer
GitHub code: Transformer-PyTorch
The Transformer encoder produces a single output,
and that output is passed into every decoder layer,
where it supplies the K and V of the multi-head attention in each layer's second sublayer-connection structure (Q comes from the decoder side).
(I know this sentence is a bit convoluted, but please make sure you understand it.)
This is also the biggest difference between the Transformer and the seq2seq model:
all word vectors can be processed in parallel, letting the multiple heads do their work! A sketch of this forward pass is below.
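A minimal sketch of the decoder layer's forward pass (the module names Self_MHA, MHA, FF, and SubLayer follow the printout below; the SublayerConnection call signature is an assumption carried over from the previous post):

import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, self_mha, mha, ff, sublayers):
        super(DecoderLayer, self).__init__()
        self.Self_MHA = self_mha   # masked multi-head self-attention
        self.MHA = mha             # encoder-decoder (cross) attention
        self.FF = ff               # feed-forward sublayer
        self.SubLayer = sublayers  # ModuleList of three SublayerConnection

    def forward(self, x, memory, source_mask, target_mask):
        m = memory  # the single encoder output, shared by every decoder layer
        # 1) Q = K = V = x: masked self-attention over the target sequence
        x = self.SubLayer[0](x, lambda y: self.Self_MHA(y, y, y, target_mask))
        # 2) Q from the decoder, K and V from the encoder output
        x = self.SubLayer[1](x, lambda y: self.MHA(y, m, m, source_mask))
        # 3) feed-forward
        return self.SubLayer[2](x, self.FF)

Printing the assembled model (N = 8 layers on each side in this configuration) gives the structure below.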
Transformer(
(TE): TextEmbedding(
(lut): Embedding(500, 512, padding_idx=0)
)
(PE): PositionalEnconding(
(dpot): Dropout(p=0.2, inplace=False)
)
(EN): Encoder(
(layers): ModuleList(
(0): EncoderLayer(
(MultiHeadAtten): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(1): EncoderLayer(
(MultiHeadAtten): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(2): EncoderLayer(
(MultiHeadAtten): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(3): EncoderLayer(
(MultiHeadAtten): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(4): EncoderLayer(
(MultiHeadAtten): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(5): EncoderLayer(
(MultiHeadAtten): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(6): EncoderLayer(
(MultiHeadAtten): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(7): EncoderLayer(
(MultiHeadAtten): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
)
)
(DE): Decoder(
(layers): ModuleList(
(0): DecoderLayer(
(Self_MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(1): DecoderLayer(
(Self_MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(2): DecoderLayer(
(Self_MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(3): DecoderLayer(
(Self_MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(4): DecoderLayer(
(Self_MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(5): DecoderLayer(
(Self_MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(6): DecoderLayer(
(Self_MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
(7): DecoderLayer(
(Self_MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(MHA): MultiHeadAtten(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dpot): Dropout(p=0.1, inplace=False)
)
(FF): FeedForward(
(w1): Linear(in_features=512, out_features=64, bias=True)
(w2): Linear(in_features=64, out_features=512, bias=True)
(dpot): Dropout(p=0.2, inplace=False)
)
(SubLayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dpot): Dropout(p=0.2, inplace=False)
)
)
)
)
)
(G): Generator(
(project): Linear(in_features=512, out_features=500, bias=True)
)
)