Transformer源码理解(Tensorflow)

主要是留给自己以后回忆用的,写的不好,评论区不要开炮,Love&Peace。

源码:https://github.com/Kyubyong/transformer

本文参考博文内容:(细节参考在每一节)

NLP系列——Transformer源码解析(TensorFlow版)

论文解读:Attention is All you need

图解Transformer(完整版)

Transformer解析与tensorflow代码解读

Transformer源码解读

FlyAI小课堂:代码解读Transformer--Attention is All You Need

非常感谢以上大佬!

代码

hparams.py

import argparse

class Hparams:
    parser = argparse.ArgumentParser()

    # prepro
    parser.add_argument('--vocab_size', default=32000, type=int)

    # train
    ## files
    parser.add_argument('--train1', default='iwslt2016/segmented/train.de.bpe',
                             help="german training segmented data")
    parser.add_argument('--train2', default='iwslt2016/segmented/train.en.bpe',
                             help="english training segmented data")
    parser.add_argument('--eval1', default='iwslt2016/segmented/eval.de.bpe',
                             help="german evaluation segmented data")
    parser.add_argument('--eval2', default='iwslt2016/segmented/eval.en.bpe',
                             help="english evaluation segmented data")
    parser.add_argument('--eval3', default='iwslt2016/prepro/eval.en',
                             help="english evaluation unsegmented data")

    ## vocabulary
    parser.add_argument('--vocab', default='iwslt2016/segmented/bpe.vocab',
                        help="vocabulary file path")

    # training scheme
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--eval_batch_size', default=128, type=int)

    parser.add_argument('--lr', default=0.0003, type=float, help="learning rate")
    parser.add_argument('--warmup_steps', default=4000, type=int) # 预热学习率
    parser.add_argument('--logdir', default="log/1", help="log directory") # 日志存储路径
    parser.add_argument('--num_epochs', default=20, type=int)
    parser.add_argument('--evaldir', default="eval/1", help="evaluation dir")

    # model
    parser.add_argument('--d_model', default=512, type=int,
                        help="hidden dimension of encoder/decoder") # 词嵌入维度
    parser.add_argument('--d_ff', default=2048, type=int,
                        help="hidden dimension of feedforward layer") # 前向传播网络隐层单元数量
    parser.add_argument('--num_blocks', default=6, type=int,
                        help="number of encoder/decoder blocks") # blocks的数量
    parser.add_argument('--num_heads', default=8, type=int,
                        help="number of attention heads") # 多头注意力 “头”的数量
    parser.add_argument('--maxlen1', default=100, type=int, 
                        help="maximum length of a source sequence") # 源句最大长度 
    parser.add_argument('--maxlen2', default=100, type=int, 
                        help="maximum length of a target sequence") # 目标句最大长度 
    parser.add_argument('--dropout_rate', default=0.3, type=float) # dropout丢弃概率
    parser.add_argument('--smoothing', default=0.1, type=float,
                        help="label smoothing rate") # 平滑率

    # test
    parser.add_argument('--test1', default='iwslt2016/segmented/test.de.bpe',
                        help="german test segmented data")
    parser.add_argument('--test2', default='iwslt2016/prepro/test.en',
                        help="english test data")
    parser.add_argument('--ckpt', help="checkpoint file path") # 保存checkpoint的地址
    parser.add_argument('--test_batch_size', default=128, type=int)
    parser.add_argument('--testdir', default="test/1", help="test result dir")

定义了一些与训练、词汇表、模型、测试相关的超参数。

argparse库

用于命令项选项与参数解析的模块。

一般为三个步骤:

  1. 创建 ArgumentParser() 对象
  2. 调用 add_argument() 方法添加参数
  3. 使用 parse_args() 解析添加的参数

参考:python学习笔记之argparse库的使用

train.py

import tensorflow as tf

from model import Transformer
from tqdm import tqdm
from data_load import get_batch
from utils import save_hparams, save_variable_specs, get_hypotheses, calc_bleu
import os
from hparams import Hparams
import math
import logging

logging.basicConfig(level=logging.INFO)

logging日志库

记录日志信息,见参考链接。

参考:logging的简单介绍

读取超参数

logging.info("# hparams")
hparams = Hparams() 
parser = hparams.parser
hp = parser.parse_args()
save_hparams(hp, hp.logdir)

利用Hparmas类实例化一个对象,获取其中参数,并将参数信息写为日志保存到logdir路径中。

准备训练/评估的批数据

logging.info("# Prepare train/eval batches")
train_batches, num_train_batches, num_train_samples = get_batch(hp.train1, hp.train2,
                                             hp.maxlen1, hp.maxlen2,
                                             hp.vocab, hp.batch_size,
                                             shuffle=True)
eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1, hp.eval2,
                                             100000, 100000,
                                             hp.vocab, hp.batch_size,
                                             shuffle=False)

# create a iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)
xs, ys = iter.get_next()

train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

调用data_load中的get_batch函数,得到batch数据。

使用给定结构创建一个新的未初始化的迭代器Iterator,且未绑定到特定的数据集。

后续使用make_initializer()绑定特定数据集。

使用模型进行训练与评估

logging.info("# Load model")
m = Transformer(hp)
loss, train_op, global_step, train_summaries = m.train(xs, ys)
y_hat, eval_summaries = m.eval(xs, ys)

向Transformer类中传递hp参数,实例化出模型对象。

用模型中的方法,根据数据集进行训练、评估。

训练

with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)    # 查找最新保存的checkpoint文件,读取模型保存好的参数
    if ckpt is None:    # 可能没有检查点
        logging.info("Initializing from scratch")    # 日志记录,从头开始初始化
        sess.run(tf.global_variables_initializer())    # 初始化变量
        save_variable_specs(os.path.join(hp.logdir, "specs"))    # 存储变量相关的信息,如变量名、大小、参数数量等
    else:
        saver.restore(sess, ckpt)    # 有检查点的话,恢复保存先前的变量,不必从头初始化

    summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)    # 保存训练过程数据的实例

    sess.run(train_init_op)     # 开始训练
    total_steps = hp.num_epochs * num_train_batches    # 训练需要的循环次数
    _gs = sess.run(global_step)
    for i in tqdm(range(_gs, total_steps+1)):    # 训练
        _, _gs, _summary = sess.run([train_op, global_step, train_summaries])
        epoch = math.ceil(_gs / num_train_batches)      # 向上取整,计算epoch
        summary_writer.add_summary(_summary, _gs)    # 保存训练过程数据

        if _gs and _gs % num_train_batches == 0:    # 根据当前进度,记录日志信息
            logging.info("epoch {} is done".format(epoch))    # 代数
            _loss = sess.run(loss) # train loss    # 计算训练损失

            logging.info("# test evaluation")
            _, _eval_summaries = sess.run([eval_init_op, eval_summaries])    # 评估效果
            sum
  • 0
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值