A First Look at the BERT Source Code: run_pretraining.py
The code discussed in this post comes from the run_pretraining.py file of the BERT project on GitHub. Reading it assumes some familiarity with the papers Attention Is All You Need and BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, as well as background knowledge of deep learning, natural language processing, and TensorFlow.
0 Preface
- About TensorFlow: this post is based on the official BERT pre-training code that Google released on GitHub, run under TensorFlow 1.13.1. For TensorFlow itself, please refer to the official website.
- About the Transformer: the Transformer is a model proposed by Google that is built entirely on attention mechanisms. To learn more about it, see the original paper Attention Is All You Need or my other post, Transformer Study Notes.
- About BERT: BERT is likewise a Transformer-based pre-trained model proposed by Google. For more on the model, see the official paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, the official implementation in the BERT repository on GitHub, and my other post, BERT Study Notes.
1 Introduction
Omitted.
2 Source Code Walkthrough
2.1 Parameter Settings
2.1.1 Required Parameters
flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string(
    "input_file", None,
    "Input TF example files (can be a glob or comma separated).")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")
- Path to the pre-trained BERT model's JSON configuration file, which specifies the model architecture
- Input TF example files (a glob or a comma-separated list)
- Output directory where the model checkpoints will be written
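These three flags are required in the literal sense: the script marks them as mandatory before handing control to tf.app.run(), so launching without any one of them fails immediately. The closing lines of run_pretraining.py:

if __name__ == "__main__":
  flags.mark_flag_as_required("input_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()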
2.1.2 Other Parameters
flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded. Must match data generation.")
- Checkpoint used to initialize the model weights (usually from a pre-trained BERT model)
- Maximum total input sequence length after WordPiece tokenization; longer sequences are truncated and shorter ones are padded
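To see how init_checkpoint is actually consumed: inside the model function, the script matches the variables of the freshly built graph against the checkpoint and initializes from it. A condensed sketch of that logic from the non-TPU branch (names follow the source):

tvars = tf.trainable_variables()
if init_checkpoint:
  # Build a map from variables in the new graph to variables in the
  # checkpoint, then load those values when the graph is initialized.
  (assignment_map, initialized_variable_names
  ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)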
flags.DEFINE_integer(
    "max_predictions_per_seq", 20,
    "Maximum number of masked LM predictions per sequence. "
    "Must match data generation.")
Maximum number of masked-LM predictions per sequence; it must match the value used when the pre-training data was generated. For details on the masked LM (MLM) objective, see the BERT paper.
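Both max_seq_length and max_predictions_per_seq fix the shapes of the serialized training examples, which is exactly why they must agree with the data-generation step. The feature spec inside the script's input_fn_builder makes this concrete:

name_to_features = {
    "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
    "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
    "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
    "masked_lm_positions": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
    "masked_lm_ids": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
    "masked_lm_weights": tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
    "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
}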
flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
flags.DEFINE_integer("save_checkpoints_steps", 1000,
"How often to save the model checkpoint.")
flags.DEFINE_integer("iterations_per_loop", 1000,
"How many steps to make in each estimator call.")
flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
- Whether to run training
- Whether to run evaluation on the dev set
- Total batch size for training
- Total batch size for evaluation
- Initial learning rate for Adam
- Number of training steps
- Number of learning-rate warmup steps
- How often (in steps) to save a model checkpoint
- How many steps to run per estimator call (relevant mainly to the TPU training loop)
- Maximum number of evaluation steps
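learning_rate, num_train_steps, and num_warmup_steps work together: the optimizer built in optimization.py ramps the learning rate up linearly over the warmup steps and otherwise follows a linear (power-1.0 polynomial) decay that reaches zero at num_train_steps. A minimal sketch of that schedule with the flag defaults plugged in (plain Python, not the TensorFlow graph version used in the source):

def bert_lr_schedule(step, init_lr=5e-5, num_train_steps=100000,
                     num_warmup_steps=10000):
  # During warmup, scale the target rate by the fraction of warmup done.
  if step < num_warmup_steps:
    return init_lr * step / num_warmup_steps
  # Afterwards, decay linearly so the rate reaches zero at num_train_steps.
  return init_lr * (1.0 - step / num_train_steps)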
2.1.3 TPU-Related Parameters
tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")
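In main(), these flags feed the TPU cluster resolver; when TPU use is enabled (via the use_tpu flag defined elsewhere in the file), the endpoint is resolved roughly like this (condensed from the source):

tpu_cluster_resolver = None
if FLAGS.use_tpu and FLAGS.tpu_name:
  # Resolve the Cloud TPU endpoint from its name, zone, and GCP project.
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)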