前言
在上一章【课程总结】day19(下):Transformer架构及注意力机制了解总结中,我们对Transformer架构以及注意力机制有了初步了解,本章将结合《The Annotated Transformer》的资料以及Transfomer_demo的示例代码,对Transformer的架构进行深入理解。
资料
Transformer_demo是一个使用Transformer架构进行翻译的示例代码,该工程与【课程总结】Day18:Seq2Seq的深入了解中的Seq2Seq实现非常类似,可以方便我们通过单步调试的方式,动态地了解Transformer架构。
- 代码地址:https://github.com/domonic18/transformer_demo.git
其他资料:《The Annotated Transformer》
- 博客地址:https://nlp.seas.harvard.edu/2018/04/03/attention.html
- Github仓库地址:https://github.com/harvardnlp/annotated-transformer
整体框架
代码分析理解
初始化流程
Transformer_demo的启动流程中:
- 第一步,在
main.py
中进行Translation()
初始化; - 第二步,调用分词器
tokenizer
提供的接口get_tokenizer()
进行字典的创建;(这部分逻辑实现与【课程总结】Day18:Seq2Seq的深入了解一致,所以这里不再赘述) - 第三步,调用
get_model()
创建Transformer模型; - 第四步,创建过程中,主要逻辑为:实例化创建
Encoder
对象和Decoder
对象。(实例化之前,提前准备了多头注意力对象attn
、前馈网络对象ff
以及位置编码对象position
) - 第五步,最后创建生成器
Genearator
对象。
对应源码:
def get_model(src_vocab,
tgt_vocab,
N=6,
d_model=512,
d_ff=2048,
h=8,
dropout=0.1):
"""
构建transformer模型
src_vocab: 源语言词典大小
tgt_vocab: 目标语言词典大小
N: 编码解码层数
d_model: 模型维度
d_ff: 前向传播层维度
h: 多头注意力的头数
dropout: 随机失活概率
"""
# 深拷贝函数
c = copy.deepcopy
attn = MultiHeadedAttention(h, d_model, dropout)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
position = PositionalEncoding(d_model, dropout)
model = EncoderDecoder(
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff)