机器学习离不开实践的验证,推荐大家可以多在FlyAI竞赛服务平台多参加训练和竞赛,以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。
目录
- 代码结构
- 调用模型前的设置模块(hparams.py,prepro.py,data_load.py,utils.py)
- transformer代码解析(modules.py , model.py )
- 训练和测试(train.py,eval.py和test.py )
一、代码结构
论文主题模块
该实现【1】相对原始论文【2】有些许不同,比如为了方便使用了IWSLT 2016德英翻译的数据集,直接用的positional embedding,把learning rate一开始就调的很小等等,不过大同小异,主要模型没有区别.(另外注意最下方的outputs为上面翻译后的词或字)
该实现一共包括以下几个文件
介绍:
download.sh:一个下载 IWSLT 2016 的脚本,需要在Linux环境下,或git bash插件内运行。
hparams.py:该文件包含所有需要用到的参数
prepro.py:从源文件中,提取生成源语言和目标语言的词汇文件。
data_load.py: 该文件包含所有关于加载数据以及批量化数据的函数。
model.py :encode和decode的模型架构,基本是调用到mudules的数据包
modules.py :网络的模型组件,比如FFN,masked mutil-head attetion,LN,Positional Encoding等
train.py :训练模型的代码,定义了模型,损失函数以及训练和保存模型的过程,包含了评估模型的效果
test.py :测试模型
utils.py:调用到的工具类代码,起到一些协助的作用
二、调用模型前的设置模块(hparams.py,prepro.py,data_load.py,utils.py)
2.1 hparams.py
该实现所用到的所又的超参数都在这个文件里面。以下是该文件的所有代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
|
parser = argparse.ArgumentParser() parser.add_argument定义变量成为一种趋势;help就是该参数的解释,故基本能够理解参含义,这里挑几个介绍下:
batch_size的大小以及初始学习速率还有日志的目录,batch_size 在后续代码中即所谓的N,参数中常会见到。最后定义了一些模型相关的参数,
maxlen1/2为一句话里最大词的长度为100个,在其他代码中就用的是T来表示,你也可以根据自己的喜好将这个参数调大;
num_epochs被设置为20,该参数表示所有出现次数少于num_epochs次的都会被当作UNK来处理;
hidden_units设置为512,隐藏节点的个数;
num_blocks:重复模块的数量,这里默认为6个
num_heads:multi-head attention 中用到的切分的头的数量
2.2 prepro.py
根据iwslt2016的原始数据,做一个预处理,放到prepro和segment文件中。
2.3 data_load.py
词和序号的转换,源目的词列表,字符串转数字,迭代返回预估值,加载数据以及批量化数据的函数等。
2.4 utils.py
计算计算batches数量def calc_num_batches(total_num, batch_size),
整数tensor转换成string tensor: def convert_idx_to_token_tensor(inputs, idx2token);
处理转换输出def postprocess(hypotheses, idx2token);
保存参数到路径def save_hparams(hparams, path)
加载参数:def load_hparams(parser, path)
保存有关变量的信息,例如它们的名称、形状和参数总数fpath:字符串。输出文件路径:def save_variable_specs(fpath)
得到假设。num_batches、num_samples:标量。sess:对象张量:要获取的目标张量dict:idx2token字典def get_hypotheses(num_batches, num_samples, sess, tensor, dict):
计算bleu(要调用perl文件,我windons进行了删除才跑过去):def calc_bleu(ref, translation)
三、transformer代码解析(modules.py , model.py )
这是最为主要的部分
3.1 modules.py
3.1.1 layer normalization
是 3.1.5 multihead_attention的子模块。归一化数据的一种方式,不过LN 是在每一个样本上计算均值和方差,有点类似CV用到的instance normalization而不是BN那种在批方向计算均值和方差!公式如下:
具体代码如下,后面被用在很多模块中都有使用
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
|
3.1.2 get_token_embeddings
初始化嵌入向量,用矩阵表示,目前为[vocab_size, num_units]的矩阵,vocab_size为词的数量,num_units为embedding size,一般根据论文中的设置,为512
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
|
3.1.3 scaled_dot_product_attention
缩放的点积注意力机制,是 3.1.5 multihead_attention的子模块
本文attention公式如上所示,和dot-product attention除了没有使用缩放因子,其他和这个一样。
additive attention和dot-product(multi-plicative) attention是最常用的两个attention 函数。为何选择上面的呢?主要为了提升效率,兼顾性能。
效率:在实践中dot-product attention要快得多,而且空间效率更高。这是因为它可以使用高度优化的矩阵乘法代码来实现。
性能:较小时,这两种方法性能表现的相近,当比较大时,addtitive attention表现优于 dot-product attention(点积在数量级上增长的幅度大,将softmax函数推向具有极小梯度的区域 )。所以加上因子拉平性能。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
|
3.1.4 mask
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
|
3.1.5 multihead_attention(重要组件)
该部分代码包含了整体框架中的这个部分:
主要公式如下:
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
|
3.1.6 ff——feed forward(重要组件)
除了attention子层之外,编码器和解码器中的每个层都包含一个完全连接的前馈网络,该网络分别相同地应用于每个位置。 该前馈网络包括两个线性变换,并在第一个的最后使用ReLU激活函数,公式表示如下:
不同position的FFN是一样的,但是不同层是不同的。
描述这种情况的另一种方式是两个内核大小为1的卷积。输入和输出的维度是,内层的维度=2048。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
|
3.1.7 label_smoothing
这部分就相当于是使矩阵中的数进行平滑处理。把0改成一个很小的数,把1改成一个比较接近于1的数。论文中说这虽然会使模型的学习更加不确定性,但是提高了准确率和BLEU score。(论文中的5.4部分)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
|
3.1.8 positional_encoding(重要模块)
该模块内如为:
位置编码,论文中3.5的内容,公式如下
其中pos是指当前词在句子中的位置,i是指向量中每个值的维度,位置编码的每个维度对应于正弦曲线。我们选择了这个函数,因为我们假设它允许模型容易地学习相对位置。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
|
3.1.9 noam_scheme
学习率的衰减,在3.2的train(self, xs, ys)模块中会用到,代码如下:
1 2 3 4 5 6 7 8 9 |
|
3.2 model.py
3.2.1 导入包
1 2 3 4 5 6 7 8 |
|
简单介绍:上面提到的大部分模块,基本都在这里用到。
3.2.2 创建Transformer类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
|
类内函数encode部分:按照下图部分一步一步进行。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
|
类内函数decode部分:按照下图部分一步一步进行。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
|
模型训练部分
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
|
模型训验证部分
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
|
四、训练和测试(train.py,eval.py和test.py)
4.1 train.py
训练模型的代码,定义了模型,损失函数以及训练和保存模型的过程 包含了评估模型的效果
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
|
4.2 test.py
测试模型,若是windowns上跑模型,注意将最后两句计算Bleu部分注释掉。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
|
结果
参考文献
【1】本次实现解读的代码:https://github.com/Kyubyong/transformer
【2】论文:https://arxiv.org/abs/1706.03762
【3】老版本代码解读 https://blog.csdn.net/mijiaoxiaosan/article/details/74909076
————————————————
更多精彩内容请访问FlyAI-AI竞赛服务平台;为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台;每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。
挑战者,都在FlyAI!!!