BERT 源码初探之 run_classifier.py
本文源码来源于 Github上的BERT 项目中的 run_classifier.py 文件。阅读本文需要对Attention Is All You Need以及BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding两篇论文有所了解,以及部分关于深度学习、自然语言处理和Tensorflow的储备知识。
0 前言
- 关于Tensorflow:本文基于谷歌官方在GitHub上公布的BERT预训练模型,基于Tensorflow 1.13.1 运行。有关Tensorflow的部分建议参照官方网站。
- 关于Transformer:Transformer是Google提出的一种完全基于注意力机制的模型,想要对齐进行了解请参照官方论文Attention Is All You Need或者我的另一篇博客Transformer 学习笔记。
- 关于BERT:BERT也是Google提出的一个基于Transformer的预训练网络模型,更多和该模型有关的内容请参照官方论文BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding、官方代码实现Github上的BERT以及我的另一篇博客BERT 学习笔记。
1 简介
2 classifier使用方法
2.1 参数设置
2.1.1 必须参数
源码中提到的必须要设置的参数如下所示:
flags.DEFINE_string(
"data_dir", None,
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
- 名称:data_dir
- 默认值: None
- 注释:输入数据的目录,需要包含任务所需的 .csv 格式(或者其他类型)的文件。
flags.DEFINE_string(
"bert_config_file", None,
"The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
- 名称:bert_config_file
- 默认值:None
- 注释:和预训练过的bert模型对应的json格式的配置文件,这指定了模型的架构。
flags.DEFINE_string("task_name", None, "The name of the task to train.")
- 名称:task_name
- 默认值:None
- 注释:要训练的任务名称。
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
- 名称:vocab_file
- 默认值:None
- 注释:BERT模型训练所用的词典文件
flags.DEFINE_string(
"output_dir", None,
"The output directory where the model checkpoints will be written.")
- 名称:output_dir
- 默认值:None
- 注释:模型checkpoint会被保存到的路径位置。
2.1.2 其他参数
flags.DEFINE_string(
"init_checkpoint", None,
"Initial checkpoint (usually from a pre-trained BERT model).")
- 名称:init_checkpoint
- 默认值:None
- 注释:初始的checkpoint(通常来源于一个预训练过的BERT模型)。
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
- 名称:do_lower_case
- 默认值:True
- 注释:是否把输入的文本全部小写。对于所有的 uncased models(不保留大小写以及重音标记)应该设为True,而对于 cased models(保留大小写以及重音标记,如果你认为这对于你的训练任务是有益的)则应该设为False。
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")
- 名称:max_seq_length
- 默认值:128
- 注释:WordPiece tokenization处理之后的最大输入序列长度。比这个长度长的输入序列将被阶段,而比这个短的输入序列将被补齐。
flags.DEFINE_bool("do_train", False, "Whether to run training.")
- 名称:do_train
- 默认值:False
- 注释:是否运行训练
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
- 名称:do_eval
- 默认值:False
- 注释:是否在验证集上进行计算(eval)。
flags.DEFINE_bool(
"do_predict", False,
"Whether to run the model in inference mode on the test set.")
- 名称:do_predict
- 默认值:False
- 注释:是否在测试集上用推断模式来运行模型。
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
- 名称:train_batch_size
- 默认值:32
- 注释:训练过程中的批尺寸。
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
- 名称:eval_batch_size
- 默认值:8
- 注释:验证过程中的批尺寸。
flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
- 名称:predict_batch_size
- 默认值:8
- 注释:预测过程中的批尺寸。
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
- 名称:learning_rate
- 默认值:5e-5
- 注释:Adam算法的初始学习率。
flags.DEFINE_float("num_train_epochs", 3.0,
"Total number of training epochs to perform.")
- 名称:num_train_epochs
- 默认值:3.0
- 注释:完成训练所需要遍历训练数据集的次数
flags.DEFINE_float(
"warmup_proportion", 0.1,
"Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10% of training.")
- 名称:warmup_proportion
- 默认值:0.1
- 注释:用于warm up的训练步数比例。其中0.1表示训练步数的10%。
flags.DEFINE_integer("save_checkpoints_steps", 1000,
"How often to save the model checkpoint.")
- 名称:save_checkpoints_steps
- 默认值:1000
- 注释:多久保存一次模型
flags.DEFINE_integer("iterations_per_loop", 1000,
"How many steps to make in each estimator call.")
- 名称:iterations_per_loop
- 默认值:1000
- 注释:每隔多少步来调用一次估计函数。
2.1.3 TPU相关参数
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
- 名称:use_tpu
- 默认值:False
- 注释:是否使用TPU
tf.flags.DEFINE_string(
"tpu_name", None,
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
- 名称:tpu_name
- 默认值:None
- 注释:用于训练的云TPU,应该是创建云TPU时使用的名称,或者是类似 grpc://ip.address.of.tpu:8470 这种类型的URL。
tf.flags.DEFINE_string(
"tpu_zone", None,
"[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
- 名称:tpu_zone
- 默认值:None
- 注释:【可选】 云TPU所在的GCE区域,如果没有指定该参数,程序将会自动在元数据中检测。
tf.flags.DEFINE_string(
"gcp_project", None,
"[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
- 名称:gcp_project
- 默认值:None
- 注释:【可选】 云TPU项目的项目名称,如果不指定程序将自动从元数据中检测。
tf.flags.DEFINE_string("master", None, "[Optional