![288c4f649738d5e172a152ff1c646206.png](https://i-blog.csdnimg.cn/blog_migrate/9d984a0dc5c7266c373080b668b5261f.jpeg)
本文主要会阅读bert源码(https://github.com/google-research/bert )中的create_pretraining_data.py文件,已完成modeling.py、optimization.py、run_pretraining.py、tokenization.py文件的源码阅读,后续会陆续阅读bert的特征抽取、下游任务训练等源码。本文大体以深度调用优先的顺序介绍了create_pretraining_data.py中的各个函数,主体分成样本生成和样本保存两个部分,样本生成中涉及全词遮蔽、下一句选取、遮蔽token的选择与生成等。
实战系列篇章中主要会分享,解决实际问题时的过程、遇到的问题或者使用的工具等等。如问题分解、bug排查、模型部署等等。相关代码实现开源在:https://github.com/wellinxu/nlp_store ,更多内容关注知乎专栏(或微信公众号):NLP杂货铺。
![8b55cb4f32c5d50a4dfaec9460397c47.png](https://i-blog.csdnimg.cn/blog_migrate/48b07ffc1f7d22cdf0f525556d7158bd.png)
- 运行参数
- 训练实例类
- main函数
- 样本生成
- create_training_instances
- create_instances_from_document
- create_masked_lm_predictions
- 训练样本写入文件
运行参数
生成训练样本的时候,需要提供相关参数,必要参数包括:输入文本路径、输出文件路径、词表文件路径。其他还有些默认但重要的参数,包括:允许最大序列长度(可以控制计算量,但也同时限制了下游任务中的输入长度),文档重复使用次数(重复利用输入文档生成训练样本,可以充分利用数据),生成比max_seq_length更短的句子的概率(为了减少预训练与微调时句子长度不一致的问题),是否进行全词遮蔽等等。
- input_file:必要参数,输入文本路径
- output_file:必要参数,输出文件路径
- vocab_file:必要参数,词表文件路径
- do_lower_case:字符是否进行小写化处理,bool,默认True
- do_whole_word_mask:是否进行全词遮蔽,bool,默认False
- max_seq_length:允许最大序列长度,int,默认128,bert_base中是512
- max_predictions_per_seq:允许被遮蔽token的最大数量,int,默认20
- random_seed:随机生成器的种子,int,默认12345
- dupe_factor:输入文档被重复使用次数,int,默认10
- masked_lm_prob:每个token被遮蔽的概率,float,默认15%
- short_seq_prob:生成比max_seq_length更短的句子的概率,float,默认10%
训练实例类
TrainingInstance是生成训练样本的基本定义,包含了部分遮蔽后的token序列,token属于第一句还是第二句的二值序列,下一句是否随机选择的bool值,以及被遮蔽的位置和真实token值。
class TrainingInstance(object):
"""单个训练实例(句子对)"""
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens # 替换遮蔽词token之后的所有token
self.segment_ids = segment_ids # 属于第一句还是第二句
self.is_random_next = is_random_next # 是否是随机下一句,“下一句预测”任务的标签
self.masked_lm_positions = masked_lm_positions # 被遮蔽的位置
self.masked_lm_labels = masked_lm_labels # 被遮蔽词的真实token
def __str__(self): # 重写str()方法
s = ""
s += "tokens: %sn" % (" ".join(
[tokenization.printable_text(x) for x in self.tokens]))
s += "segment_ids: %sn" % (" ".join([str(x) for x in self.segment_ids]))
s += "is_random_next: %sn" % self.is_random_next
s += "masked_lm_positions: %sn" % (" ".join(
[str(x) for x in self.masked_lm_positions]))
s += "masked_lm_labels: %sn" % (" ".join(
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
s += "n"
return s
def __repr__(self):
return self.__str__()
main函数
运行文件时,会先标记下必要参数,然后就开始运行主体函数:
if __name__ == "__main__":
# 必要参数
flags.mark_flag_as_required("input_file")
flags.mark_flag_as_required("output_file")
flags.mark_flag_as_required("vocab_file")
tf.app.run() # 运行main()函数
main函数中的逻辑比较简单,先后解析了输入输出路径,基于此就两步操作,create_training_instances(从行文本数据中创建训练样本)、write_instance_to_example_files(根据训练样本,写入TF样本文件)。
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# token切分类
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
input_files = [] # 输入文本路径list
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
tf.logging.info("*** Reading from input files ***")
for input_file in input_files:
tf.logging.info(" %s", input_file)
rng = random.Random(FLAGS.random_seed) # 随机数生成器
# 从行文本数据中创建训练样本
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng)
output_files = FLAGS.output_file.split(",") # 输出路径list
tf.logging.info("*** Writing to output files ***")
for output_file in output_files:
tf.logging.info(" %s", output_file)
# 根据训练样本,写入TF样本文件
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files)
样本生成
![73bef74c24df7b1b264cd45ef1c545e4.png](https://i-blog.csdnimg.cn/blog_migrate/8eb2093d5bd01c626881f33d67930fcb.jpeg)
输入的文本有两个格式要求:一句一行,真正的句子边界,在“下一句预测”任务中需要用到;文档之间用空行分割。具体示例如下:
这是第一篇文档的第一句。
这是第二句!
这是结束句。
这是第二篇文档的开头句。
这是中间句。
这是最后一句。
这是第三篇文档的开始句。
...
create_training_instances
该函数会从行文本数据中创建训练样本,具体伪代码逻辑如下:
> 1. 初始化文档集合与最终输出样本集合
> 2. 按行读取输入文本文件:
> 3. 如果当前是空行:
> 4. 则添加一个新文档(空行用来分割文档)
> 5. 否则:
> 6. 将行数据token化之后,添加到上一个文档里
> 7. 移除空文档,并随机打乱文档顺序
> 8.
> 9. 进行dupe_factor(默认10次)次循环:
> 10. 对每一个文档:
> 11. ## 下面小节中会具体介绍根据单个文档的样本生成方法
> 12. 生成一批训练样本[TrainingInstance],并添加到输入集合中。
> 13. 随机打乱样本顺序
> 14. 输出训练样本集合
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
"""
从行文本数据中创建训练样本
:param input_files: 输入行文件路径集合,[str]
:param tokenizer: token切分类,FullTokenizer
:param max_seq_length: 允许的最大序列长度,int
:param dupe_factor: 数据重复使用次数,默认10次,int
:param short_seq_prob: 生成短句子的概率,默认10%,float
:param masked_lm_prob: 被遮蔽的概率,默认15%, float
:param max_predictions_per_seq: 允许最多被遮蔽的token数目,int
:param rng: 随机数生成器
:return: 训练样本list,[TrainingInstance]
"""
all_documents = [[]] # 外层索引表示每篇文章,内层索引表示每句话
# 输入文本格式:
#(1)一句一行。需要是正真的一句,不能是一整个段落,也不能是文本的某个截断。
# 因为在“下一句预测”任务中需要用到句子边界。
# (2)文档之间用空行分割。这样“下一句预测”任务就不会跨越两个文档。
# 读取文本数据并切分成token
for input_file in input_files:
with tf.gfile.GFile(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
# 空行用来分割文档
if not line:
all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
all_documents[-1].append(tokens)
# 移除空文档并打乱顺序
all_documents = [x for x in all_documents if x]
rng.shuffle(all_documents)
vocab_words = list(tokenizer.vocab.keys())
instances = []
# 多次重复数据,每次都会对数据进行不同的mask,充分利用训练数据
for _ in range(dupe_factor):
for document_index in range(len(all_documents)):
instances.extend(
# 根据单个文档生成训练样本
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
rng.shuffle(instances) # 打乱样本顺序
return instances
create_instances_from_document
这个函数会根据单个文档生成训练样本,具体处理逻辑有两个:一个是生成当前生成样本的目标长度,大部分情况下都会将训练样本填充至允许的最大长度,这样可以节省算力,但是这样预训练跟微调截断的文本长度就有明显的差异,为了减少这样的差距,在10%的情况下,只会生成一个更短的样本,其长度是2到最大长度之间的一个随机值;另一个是“下一句预测”任务中,第二句的选择与拼接。 伪代码逻辑如下:
> 1. 获取当前生成样本的目标长度:
> 2. 90%的情况下取允许的最大长度,10%的情况下随机选择2到最大长度之间的一个值。
> 3.
> 4. 初始化输出样本集合与当前token集合current_chunk。
> 5. 循环当前文档中所有句子的索引值:
> 6. 将当前索引的句子token添加到current_chunk中。
> 7. 如果达到最后一句或者current_chunk中token的长度大于等于目标长度:
> 8. 随机选择current_chunk中的前几个句子作为样本的“”“A”句。
> 9. 如果current_chunk中只有一句或者50%的概率下:
> 10. 根据样本目标长度与“A”句长度,获取“B”句目标长度。
> 11. 从总文档集合中随机抽取另外一篇文档。
> 12. 从随机文档中随机选择一个句子作为起点,以“B”句目标长度为限,获取样本的“B”句。
> 13. 将当前索引退回到current_chunk中没有使用的句子索引上。
> 11. 否则:
> 12. 将current_chunk剩下的句子作为样本中的“B”句。
> 13. 将句子对截断到最大限制长度。
> 14. 添加开头的“[CLS]”,第一句结尾的“[SEP]”和第二句结尾的“[SEP]”字符。
> 15. ## 下面小节会详细讲解token遮蔽的方法
> 16. 根据遮蔽语言模型,随机遮蔽token,生成预测目标。
> 17. 获取当个样本实例,并添加到输出样本集合中。
> 18. 置空current_chunk。
> 19. 返回样本集合。
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
"""
根据单个文档生成训练样本
:param all_documents: 所有的文档list,用来随机抽取下一句使用的
:param document_index: 当前文档的索引,int
:param max_seq_length: 允许的最大序列长度,int
:param short_seq_prob: 生成短句子的概率,默认10%,float
:param masked_lm_prob: 被遮蔽的概率,默认15%, float
:param max_predictions_per_seq: 允许最多被遮蔽的token数目,int
:param vocab_words: 词表, list
:param rng: 随机数生成器, Random
:return: 单个训练样本,TrainingInstance
"""
document = all_documents[document_index]
#最大长度要给 [CLS](开头), [SEP](第一句结尾), [SEP](结尾)留三个位置
max_num_tokens = max_seq_length - 3
# 大多数情况下,我们都会填满整个句子,直到“max_seq_length”长度,因为短文本太浪费算力了。
# 但是,我们有时(short_seq_prob=10%的情况)需要用更短的句子来减少预训练与微调阶段之间的差距。
# “target_seq_length”只是一个粗略的长度目标,而“max_seq_length”是固定的限制。
target_seq_length = max_num_tokens
if rng.random() < short_seq_prob: # short_seq_prob=10%的情况下,句子长度会随机取一个小于最大长度的值
target_seq_length = rng.randint(2, max_num_tokens)
# 我们不会仅仅将文档中的tokens拼接到一个长句中,然后任意地选择一个切分点,因为这样会让“下一句预测”任
# 务太简单。实际上,我们会根据用户输入,根据实际的句子将输入分成“A”句跟“B”句。
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
a_end = 1 # current_chunk中句子放到“A”句(第一句)中的数量
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# 下一句是否随机选择
is_random_next = False
# current_chunk中只有一句,或者50%的概率,下一句进行随机选择
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a) # “B”句的最大长度限制
# 对大型语料来说,很少需要一次以上的迭代,但是为了小心起见,我们尽量保证
# 随机的文档不是当前正在处理的文档。
for _ in range(10):
random_document_index = rng.randint(0, len(all_documents) - 1)
if random_document_index != document_index:
break
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1) # 随机选择一句作为"B"句的开始
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# 随机选择下一句的时候,current_chunk中有一些句子是没有用到过的,
# 为了不浪费语料,将这些句子“放回去”
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
else: # 真的下一句
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) # 将句子对截断到最大限制长度
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = [] # 原始句子对token
segment_ids = [] # 判断属于第一句还是第二句
tokens.append("[CLS]") # 添加开头token
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]") # 添加第一句结尾token
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]") # 添加结尾token
segment_ids.append(1)
# 根据遮蔽语言模型,生成预测目标
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
# 单个训练实例(句子对)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
is_random_next=is_random_next,
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
"""
将句子对截断到最大限制长度
:param tokens_a: “A”句token, list
:param tokens_b: “B”句token,list
:param max_num_tokens: 允许最大token数目, int
:param rng: 随机数生成器,Random
:return: 没有返回,原地修改
"""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
# 截断长度较长的一句
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# 为了避免偏差,50%的情况下我们丢弃开头的token,50%的情况下丢弃结尾的token
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
create_masked_lm_predictions
这个函数会根据遮蔽语言模型,随机遮蔽token,并生成预测目标。其中会涉及到全词遮蔽的思想,具体指:某个词片段被遮蔽了,则将此词片段所属的整个词都进行遮蔽。实际过程中,在token化的时候,将非开头的词片段加上“##”前缀,这样可以确定词边界来进行全词遮蔽。具体伪代码逻辑如下:
> 1. 初始化候选索引集合。
> 2. 循环遍历句子中每一个token:
> 3. 如果是“[CLS]”或“[SEP]”则跳过。
> 4. 如果要进行全词遮蔽且当前token有前缀“##”:
> 5. 则将当前token的索引添加到上一个候选集中。
> 6. 否则:
> 7. 以当前索引新建一个候选集,并添加到整体候选集合里。
> 8. 随机打乱候选索引集合。
> 9.
> 10. 获取随机的被遮蔽数量。
> 11. 初始化被遮蔽后的token序列以及被遮蔽的token序列。
> 12. 循环遍历所有的候选集合:
> 13. 如果遮蔽的数量足够,则跳出。
> 14. 如果加上当前候选集合中token的数量,超过限制,则忽略当前集合。
> 15. 如果集合中存在已经被遮蔽过的token,则忽略当前集合。
> 16. 循环遍历当前集合中每一个索引:
> 17. 当前索引的token以80%的概率变成“[MASK]”,10%的概率变为一个随机token,10%的概率保持不变。
> 18. 保存当前被遮蔽的索引以及真实token值。
> 19. 按顺序获取被遮蔽的token位置及真实值。
> 20. 返回被遮蔽后的token序列、被遮蔽的位置以及真实值。
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
"""
根据遮蔽语言模型,生成预测目标
:param tokens: 原始句子对token, list
:param masked_lm_prob: 被遮蔽的概率,默认15%, float
:param max_predictions_per_seq: 允许最多被遮蔽的token数目,int
:param vocab_words: 词表,list
:param rng: 随机数生成器,Random
:return: 被遮蔽后的句子对output_tokens,list
被遮蔽的位置masked_lm_positions, list
被遮蔽的真实tokenmasked_lm_labels, list
"""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
# WWM全词遮蔽表示我们会遮蔽跟原词相关的所有词片段。当词被分层词片段,第一个token
# 没有任何标记,后面的每一个token都会加上“##”前缀。所以一旦看见带##的token,我
# 们会将它添加到前面一个词索引的集合里。
# 要注意到,全词遮蔽并没有改变训练代码,我们依然需要独立地预测每一个词片段,
# 需要计算整个词表的softmax。
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
rng.shuffle(cand_indexes) # 打乱候选索引集
output_tokens = list(tokens) # 被遮蔽后的句子对token
# 被遮蔽数量
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# 如果添加当前候选集后,遮蔽的数量超过了最大限制,则跳过此候选集
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False # 判断token是否出现错
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue # 如果token出现过,则跳过,为了尽量遮蔽不同的token
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80%会将token换成[MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
# 10%会保持原词
if rng.random() < 0.5:
masked_token = tokens[index]
# 10%会替换成一个随机的词
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token # 将遮蔽位置的词替换掉
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_positions = [] # 被遮蔽的位置
masked_lm_labels = [] # 被遮蔽词的真实token
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels)
训练样本写入文件
上面一小节已经得到了训练样本,但格式上还不能直接作为模型的 输入,本小节会将已有样本改成模型可以直接输入的格式,并写入到文件中。为了保持所有训练数据的一致性,会将每句的token id用0补全到最大允许长度,新增input_mask的二值列表,用来判断哪些是增补的,同样的被遮蔽序列也会用0补全到最大允许被遮蔽数目,新增masked_lm_weights的二值列表,来判断哪些是增补的。将补齐后的数据转换成TF Example格式(可参考NLP实战篇之tf2数据输入),然后写入文件中。
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files):
"""
根据训练样本,写入TF样本文件
:param instances: 训练样本list, [TrainingInstance]
:param tokenizer: token切分类,FullTokenizer
:param max_seq_length: 允许最大序列长度,int
:param max_predictions_per_seq: 允许最大遮蔽数目,int
:param output_files: TF exmaple输出文件路径, list
:return:
"""
writers = []
for output_file in output_files: # 文件会很大,所以会选择写入多个文件
writers.append(tf.python_io.TFRecordWriter(output_file))
writer_index = 0
total_written = 0
for (inst_index, instance) in enumerate(instances):
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) # token转为id
input_mask = [1] * len(input_ids) # 有实际输入的位置为1,没有实际输入的后续用0pad
segment_ids = list(instance.segment_ids)
assert len(input_ids) <= max_seq_length
# 将输入序列长度padding到“max_seq_length”,用0padding
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = list(instance.masked_lm_positions) # 被遮蔽的位置
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) # 被遮蔽token的id
masked_lm_weights = [1.0] * len(masked_lm_ids) # 实际遮蔽的权重为1,没有实际遮蔽的权重后续用0pad
# 将遮蔽序列长度padding到“max_predictions_per_seq”,用0padding
while len(masked_lm_positions) < max_predictions_per_seq:
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance.is_random_next else 0 # 下一句预测的标签
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids) # 输入的id, [max_seq_length]
features["input_mask"] = create_int_feature(input_mask) # 输入的mask, [max_seq_length]
features["segment_ids"] = create_int_feature(segment_ids) # 第一句、第二句, [max_seq_length]
features["masked_lm_positions"] = create_int_feature(masked_lm_positions) # 语言模型中被遮蔽的位置, [max_predictions_per_seq]
features["masked_lm_ids"] = create_int_feature(masked_lm_ids) # 遮蔽语言模型的标签, [max_predictions_per_seq]
features["masked_lm_weights"] = create_float_feature(masked_lm_weights) # 遮蔽语言模型中被遮蔽的标签的权重, [max_predictions_per_seq]
features["next_sentence_labels"] = create_int_feature([next_sentence_label]) # 下一句预测的标签, [1]
# 转为TF example
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writers[writer_index].write(tf_example.SerializeToString()) # 写入文件
writer_index = (writer_index + 1) % len(writers)
total_written += 1
if inst_index < 20: # 展示前20个样本
tf.logging.info("*** Example ***")
tf.logging.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in instance.tokens]))
for feature_name in features.keys():
feature = features[feature_name]
values = []
if feature.int64_list.value:
values = feature.int64_list.value
elif feature.float_list.value:
values = feature.float_list.value
tf.logging.info(
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
for writer in writers:
writer.close()
tf.logging.info("Wrote %d total instances", total_written)
def create_int_feature(values):
# int类型数据转换
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return feature
def create_float_feature(values):
# float类型数据转换
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return feature