BERT 源码初探之 create_pretraining_data.py
本文源码来源于 Github上的BERT 项目中的 run_pretraining.py 文件。阅读本文需要对Attention Is All You Need以及BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding两篇论文有所了解,以及部分关于深度学习、自然语言处理和Tensorflow的储备知识。
0 前言
- 关于Tensorflow:本文基于谷歌官方在GitHub上公布的BERT预训练模型,基于Tensorflow 1.13.1 运行。有关Tensorflow的部分建议参照官方网站。
- 关于Transformer:Transformer是Google提出的一种完全基于注意力机制的模型,想要对齐进行了解请参照官方论文Attention Is All You Need或者我的另一篇博客Transformer 学习笔记。
- 关于BERT:BERT也是Google提出的一个基于Transformer的预训练网络模型,更多和该模型有关的内容请参照官方论文BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding、官方代码实现Github上的BERT以及我的另一篇博客BERT 学习笔记。
1 简介
要使用时才发现 BERT 提供了把文本数据转化为预训练模型所需的数据的代码,因此本文就来阅读这一部分代码吧。
2 源码解释
2.1 参数定义
2.1.1 必须参数
flags.DEFINE_string("input_file", None,
"Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
"output_file", None,
"Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
- 文件输入路径
- 输出文件路径
- 词典文件路径
2.2.2 可选参数
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
"Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
- 是否小写输入
- 最大句子的长度
- 每一句MLM预测的百分比
- 随机数种子(用于数据生成)
flags.DEFINE_integer(
"dupe_factor", 10,
"Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float(
"short_seq_prob", 0.1,
"Probability of creating sequences which are shorter than the "
"maximum length.")
- 复制输入数据的次数(采用不同的masks)
- MLM的比例
- 生成小于最大长度的句子的概率
2.2 训练实例
2.2.1 一个单独的训练实例(TrainingInstance)
class TrainingInstance(object):
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens
self.segment_ids = segment_ids
self.is_