BERT source code: https://github.com/google-research/bert
Module under study: create_pretraining_data
Purpose: defines how plain text is converted into the TFRecord files used to pre-train a BERT model.
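For reference, each tf.train.Example the script writes to the TFRecord contains seven features: input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights, and next_sentence_labels (see write_instance_to_example_files in the same file).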
01 Execution parameters and their meanings:
--input_file=./sample_text.txt \          # training text
--output_file=/tmp/tf_examples.tfrecord \ # output TFRecord, written under /tmp
--vocab_file=$BERT_BASE_DIR/vocab.txt \   # vocabulary file, prepared in advance
--do_lower_case=True \                    # whether to lower-case the input; default True
--max_seq_length=128 \                    # maximum sequence length
--max_predictions_per_seq=20 \            # maximum number of masked LM predictions per sequence
--masked_lm_prob=0.15 \                   # masked LM probability (fraction of tokens to mask)
--random_seed=12345 \                     # random seed for data generation
--dupe_factor=5                           # number of times to duplicate the input data (with different masks); the flag's default is 10
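Note that masked_lm_prob and max_predictions_per_seq are related: with max_seq_length=128 and masked_lm_prob=0.15, the script masks round(128 × 0.15) = 19 positions in a full-length sequence, so max_predictions_per_seq=20 acts as a cap just above that value.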
Source code:
Required parameters:
flags.DEFINE_string("input_file", None,
"Input raw text file (or comma-separated list of files).")flags.DEFINE_string(
"output_file", None,
"Output TF example file (or comma-separated list of files).")flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")可选参数
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_bool(
    "do_whole_word_mask", False,
    "Whether to use whole word masking rather than per-WordPiece masking.")

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")

flags.DEFINE_integer("max_predictions_per_seq", 20,
                     "Maximum number of masked LM predictions per sequence.")

flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")

flags.DEFINE_integer(
    "dupe_factor", 10,
    "Number of times to duplicate the input data (with different masks).")

flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")

flags.DEFINE_float(
    "short_seq_prob", 0.1,
    "Probability of creating sequences which are shorter than the "
    "maximum length.")
02 Code walkthrough
Single training instance
# Class definition
class TrainingInstance(object):
  """A single training instance (sentence pair)."""

  # Initialization
  def __init__(self, tokens, segment_ids, masked_lm_positions,
               masked_lm_labels, is_random_next):
    self.tokens = tokens
    self.segment_ids = segment_ids
    self.is_random_next = is_random_next
    self.masked_lm_positions = masked_lm_positions
    self.masked_lm_labels = masked_lm_labels

  # Stringification
  def __str__(self):
    s = ""
    s += "tokens: %s\n" % (" ".join(
        [tokenization.printable_text(x) for x in self.tokens]))
    s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
    s += "is_random_next: %s\n" % self.is_random_next
    s += "masked_lm_positions: %s\n" % (" ".join(
        [str(x) for x in self.masked_lm_positions]))
    s += "masked_lm_labels: %s\n" % (" ".join(
        [tokenization.printable_text(x) for x in self.masked_lm_labels]))
    s += "\n"
    return s

  def __repr__(self):
    return self.__str__()
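A minimal usage sketch (the token values here are made up for illustration, and it assumes the repo's tokenization module is importable):

# One [MASK]ed sentence pair; segment_ids separate sentence A (0) from B (1).
instance = TrainingInstance(
    tokens=["[CLS]", "he", "went", "[MASK]", "[SEP]", "the", "store", "[SEP]"],
    segment_ids=[0, 0, 0, 0, 0, 1, 1, 1],
    masked_lm_positions=[3],       # index of the [MASK] token
    masked_lm_labels=["to"],       # the original token at that position
    is_random_next=False)          # sentence B really follows sentence A
print(instance)  # __str__/__repr__ print one field per line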