Transformer代码讲解
原论文:https://arxiv.org/pdf/1706.03762v5.pdf
本文将从以下部分进行讲解:
一、Transformer结构展开图
1.原图
2.结构展开图
二、Transformer代码
1.数据预处理
2. 代码拆分
2.1 positional encoding
2.2 pad mask
2.3 subsequence mask
2.4 ScaledDotProductAttention(计算 context vector)
2.5 multiheadattention
2.6 feedforward layer
2.7 encoder layer
2.8 encoder
2.9 decoder layer
2.10 decoder
3. transformer
4. 模型、损失函数、优化器
5. 训练
6. 测试
正文:
一、Transformer结构展开图
1.原图
1.在原论文中N=6,也就是分别有6个Encoder和Decoder。
2.原论文每一个Decoder的enc_inputs都是最后一个Encoder的输出。如下图:
2.结构展开图
将N=6带入到上图中,得到Transformer结构展开图。原论文每一个Decoder的enc_inputs都是最后一个Encoder的输出。如下图: