主要是留给自己以后回忆用的,写的不好,评论区不要开炮,Love&Peace。
源码:https://github.com/Kyubyong/transformer
本文参考博文内容:(细节参考在每一节)
NLP系列——Transformer源码解析(TensorFlow版)
论文解读:Attention is All you need
FlyAI小课堂:代码解读Transformer--Attention is All You Need
非常感谢以上大佬!
代码
hparams.py
import argparse class Hparams: parser = argparse.ArgumentParser() # prepro parser.add_argument('--vocab_size', default=32000, type=int) # train ## files parser.add_argument('--train1', default='iwslt2016/segmented/train.de.bpe', help="german training segmented data") parser.add_argument('--train2', default='iwslt2016/segmented/train.en.bpe', help="english training segmented data") parser.add_argument('--eval1', default='iwslt2016/segmented/eval.de.bpe', help="german evaluation segmented data") parser.add_argument('--eval2', default='iwslt2016/segmented/eval.en.bpe', help="english evaluation segmented data") parser.add_argument('--eval3', default='iwslt2016/prepro/eval.en', help="english evaluation unsegmented data") ## vocabulary parser.add_argument('--vocab', default='iwslt2016/segmented/bpe.vocab', help="vocabulary file path") # training scheme parser.add_argument('--batch_size', default=128, type=int) parser.add_argument('--eval_batch_size', default=128, type=int) parser.add_argument('--lr', default=0.0003, type=float, help="learning rate") parser.add_argument('--warmup_steps', default=4000, type=int) # 预热学习率 parser.add_argument('--logdir', default="log/1", help="log directory") # 日志存储路径 parser.add_argument('--num_epochs', default=20, type=int) parser.add_argument('--evaldir', default="eval/1", help="evaluation dir") # model parser.add_argument('--d_model', default=512, type=int, help="hidden dimension of encoder/decoder") # 词嵌入维度 parser.add_argument('--d_ff', default=2048, type=int, help="hidden dimension of feedforward layer") # 前向传播网络隐层单元数量 parser.add_argument('--num_blocks', default=6, type=int, help="number of encoder/decoder blocks") # blocks的数量 parser.add_argument('--num_heads', default=8, type=int, help="number of attention heads") # 多头注意力 “头”的数量 parser.add_argument('--maxlen1', default=100, type=int, help="maximum length of a source sequence") # 源句最大长度 parser.add_argument('--maxlen2', default=100, type=int, help="maximum length of a target sequence") # 目标句最大长度 parser.add_argument('--dropout_rate', default=0.3, type=float) # dropout丢弃概率 parser.add_argument('--smoothing', default=0.1, type=float, help="label smoothing rate") # 平滑率 # test parser.add_argument('--test1', default='iwslt2016/segmented/test.de.bpe', help="german test segmented data") parser.add_argument('--test2', default='iwslt2016/prepro/test.en', help="english test data") parser.add_argument('--ckpt', help="checkpoint file path") # 保存checkpoint的地址 parser.add_argument('--test_batch_size', default=128, type=int) parser.add_argument('--testdir', default="test/1", help="test result dir")
定义了一些与训练、词汇表、模型、测试相关的超参数。
argparse库
用于命令项选项与参数解析的模块。
一般为三个步骤:
- 创建 ArgumentParser() 对象
- 调用 add_argument() 方法添加参数
- 使用 parse_args() 解析添加的参数
train.py
import tensorflow as tf
from model import Transformer
from tqdm import tqdm
from data_load import get_batch
from utils import save_hparams, save_variable_specs, get_hypotheses, calc_bleu
import os
from hparams import Hparams
import math
import logging
logging.basicConfig(level=logging.INFO)
logging日志库
记录日志信息,见参考链接。
参考:logging的简单介绍
读取超参数
logging.info("# hparams")
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
save_hparams(hp, hp.logdir)
利用Hparmas类实例化一个对象,获取其中参数,并将参数信息写为日志保存到logdir路径中。
准备训练/评估的批数据
logging.info("# Prepare train/eval batches")
train_batches, num_train_batches, num_train_samples = get_batch(hp.train1, hp.train2,
hp.maxlen1, hp.maxlen2,
hp.vocab, hp.batch_size,
shuffle=True)
eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1, hp.eval2,
100000, 100000,
hp.vocab, hp.batch_size,
shuffle=False)
# create a iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)
xs, ys = iter.get_next()
train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)
调用data_load中的get_batch函数,得到batch数据。
使用给定结构创建一个新的未初始化的迭代器Iterator,且未绑定到特定的数据集。
后续使用make_initializer()绑定特定数据集。
使用模型进行训练与评估
logging.info("# Load model")
m = Transformer(hp)
loss, train_op, global_step, train_summaries = m.train(xs, ys)
y_hat, eval_summaries = m.eval(xs, ys)
向Transformer类中传递hp参数,实例化出模型对象。
用模型中的方法,根据数据集进行训练、评估。
训练
with tf.Session() as sess:
ckpt = tf.train.latest_checkpoint(hp.logdir) # 查找最新保存的checkpoint文件,读取模型保存好的参数
if ckpt is None: # 可能没有检查点
logging.info("Initializing from scratch") # 日志记录,从头开始初始化
sess.run(tf.global_variables_initializer()) # 初始化变量
save_variable_specs(os.path.join(hp.logdir, "specs")) # 存储变量相关的信息,如变量名、大小、参数数量等
else:
saver.restore(sess, ckpt) # 有检查点的话,恢复保存先前的变量,不必从头初始化
summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph) # 保存训练过程数据的实例
sess.run(train_init_op) # 开始训练
total_steps = hp.num_epochs * num_train_batches # 训练需要的循环次数
_gs = sess.run(global_step)
for i in tqdm(range(_gs, total_steps+1)): # 训练
_, _gs, _summary = sess.run([train_op, global_step, train_summaries])
epoch = math.ceil(_gs / num_train_batches) # 向上取整,计算epoch
summary_writer.add_summary(_summary, _gs) # 保存训练过程数据
if _gs and _gs % num_train_batches == 0: # 根据当前进度,记录日志信息
logging.info("epoch {} is done".format(epoch)) # 代数
_loss = sess.run(loss) # train loss # 计算训练损失
logging.info("# test evaluation")
_, _eval_summaries = sess.run([eval_init_op, eval_summaries]) # 评估效果
sum