cannot create file怎么解决_NLP实战篇之bert源码阅读(create_pretraining_data)

288c4f649738d5e172a152ff1c646206.png

本文主要会阅读bert源码(https://github.com/google-research/bert )中的create_pretraining_data.py文件,已完成modeling.pyoptimization.pyrun_pretraining.pytokenization.py文件的源码阅读,后续会陆续阅读bert的特征抽取、下游任务训练等源码。本文大体以深度调用优先的顺序介绍了create_pretraining_data.py中的各个函数,主体分成样本生成和样本保存两个部分,样本生成中涉及全词遮蔽、下一句选取、遮蔽token的选择与生成等。

实战系列篇章中主要会分享,解决实际问题时的过程、遇到的问题或者使用的工具等等。如问题分解、bug排查、模型部署等等。相关代码实现开源在:https://github.com/wellinxu/nlp_store ,更多内容关注知乎专栏(或微信公众号):NLP杂货铺。

8b55cb4f32c5d50a4dfaec9460397c47.png
  • 运行参数
  • 训练实例类
  • main函数
  • 样本生成
    • create_training_instances
    • create_instances_from_document
    • create_masked_lm_predictions
  • 训练样本写入文件

运行参数

生成训练样本的时候,需要提供相关参数,必要参数包括:输入文本路径、输出文件路径、词表文件路径。其他还有些默认但重要的参数,包括:允许最大序列长度(可以控制计算量,但也同时限制了下游任务中的输入长度),文档重复使用次数(重复利用输入文档生成训练样本,可以充分利用数据),生成比max_seq_length更短的句子的概率(为了减少预训练与微调时句子长度不一致的问题),是否进行全词遮蔽等等。

  1. input_file:必要参数,输入文本路径
  2. output_file:必要参数,输出文件路径
  3. vocab_file:必要参数,词表文件路径
  4. do_lower_case:字符是否进行小写化处理,bool,默认True
  5. do_whole_word_mask:是否进行全词遮蔽,bool,默认False
  6. max_seq_length:允许最大序列长度,int,默认128,bert_base中是512
  7. max_predictions_per_seq:允许被遮蔽token的最大数量,int,默认20
  8. random_seed:随机生成器的种子,int,默认12345
  9. dupe_factor:输入文档被重复使用次数,int,默认10
  10. masked_lm_prob:每个token被遮蔽的概率,float,默认15%
  11. short_seq_prob:生成比max_seq_length更短的句子的概率,float,默认10%

训练实例类

TrainingInstance是生成训练样本的基本定义,包含了部分遮蔽后的token序列,token属于第一句还是第二句的二值序列,下一句是否随机选择的bool值,以及被遮蔽的位置和真实token值。

class TrainingInstance(object):
  """单个训练实例(句子对)"""

  def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
               is_random_next):
    self.tokens = tokens    # 替换遮蔽词token之后的所有token
    self.segment_ids = segment_ids    # 属于第一句还是第二句
    self.is_random_next = is_random_next    # 是否是随机下一句,“下一句预测”任务的标签
    self.masked_lm_positions = masked_lm_positions    # 被遮蔽的位置
    self.masked_lm_labels = masked_lm_labels    # 被遮蔽词的真实token

  def __str__(self):    # 重写str()方法
    s = ""
    s += "tokens: %sn" % (" ".join(
        [tokenization.printable_text(x) for x in self.tokens]))
    s += "segment_ids: %sn" % (" ".join([str(x) for x in self.segment_ids]))
    s += "is_random_next: %sn" % self.is_random_next
    s += "masked_lm_positions: %sn" % (" ".join(
        [str(x) for x in self.masked_lm_positions]))
    s += "masked_lm_labels: %sn" % (" ".join(
        [tokenization.printable_text(x) for x in self.masked_lm_labels]))
    s += "n"
    return s

  def __repr__(self):
    return self.__str__()

main函数

运行文件时,会先标记下必要参数,然后就开始运行主体函数:

if __name__ == "__main__":
  # 必要参数
  flags.mark_flag_as_required("input_file")
  flags.mark_flag_as_required("output_file")
  flags.mark_flag_as_required("vocab_file")
  tf.app.run()    # 运行main()函数

main函数中的逻辑比较简单,先后解析了输入输出路径,基于此就两步操作,create_training_instances(从行文本数据中创建训练样本)、write_instance_to_example_files(根据训练样本,写入TF样本文件)。

def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  # token切分类
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  input_files = []    # 输入文本路径list
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Reading from input files ***")
  for input_file in input_files:
    tf.logging.info("  %s", input_file)

  rng = random.Random(FLAGS.random_seed)    # 随机数生成器
  # 从行文本数据中创建训练样本
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
      rng)

  output_files = FLAGS.output_file.split(",")    # 输出路径list
  tf.logging.info("*** Writing to output files ***")
  for output_file in output_files:
    tf.logging.info("  %s", output_file)

  # 根据训练样本,写入TF样本文件
  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files)

样本生成

73bef74c24df7b1b264cd45ef1c545e4.png

输入的文本有两个格式要求:一句一行,真正的句子边界,在“下一句预测”任务中需要用到;文档之间用空行分割。具体示例如下:

这是第一篇文档的第一句。
这是第二句!
这是结束句。

这是第二篇文档的开头句。
这是中间句。
这是最后一句。

这是第三篇文档的开始句。
...

create_training_instances

该函数会从行文本数据中创建训练样本,具体伪代码逻辑如下:

> 1. 初始化文档集合与最终输出样本集合
> 2. 按行读取输入文本文件:
> 3.      如果当前是空行:
> 4.          则添加一个新文档(空行用来分割文档)
> 5.      否则:
> 6.          将行数据token化之后,添加到上一个文档里
> 7.  移除空文档,并随机打乱文档顺序
> 8.    
> 9.  进行dupe_factor(默认10次)次循环:
> 10.     对每一个文档:
> 11.         ## 下面小节中会具体介绍根据单个文档的样本生成方法
> 12.         生成一批训练样本[TrainingInstance],并添加到输入集合中。
> 13. 随机打乱样本顺序
> 14. 输出训练样本集合

def create_training_instances(input_files, tokenizer, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng):
  """
  从行文本数据中创建训练样本
  :param input_files:  输入行文件路径集合,[str]
  :param tokenizer: token切分类,FullTokenizer
  :param max_seq_length: 允许的最大序列长度,int
  :param dupe_factor: 数据重复使用次数,默认10次,int
  :param short_seq_prob: 生成短句子的概率,默认10%,float
  :param masked_lm_prob: 被遮蔽的概率,默认15%, float
  :param max_predictions_per_seq: 允许最多被遮蔽的token数目,int
  :param rng: 随机数生成器
  :return: 训练样本list,[TrainingInstance]
  """
  all_documents = [[]]    # 外层索引表示每篇文章,内层索引表示每句话

  # 输入文本格式:
  #(1)一句一行。需要是正真的一句,不能是一整个段落,也不能是文本的某个截断。
  # 因为在“下一句预测”任务中需要用到句子边界。
  # (2)文档之间用空行分割。这样“下一句预测”任务就不会跨越两个文档。

  # 读取文本数据并切分成token
  for input_file in input_files:
    with tf.gfile.GFile(input_file, "r") as reader:
      while True:
        line = tokenization.convert_to_unicode(reader.readline())
        if not line:
          break
        line = line.strip()

        # 空行用来分割文档
        if not line:
          all_documents.append([])
        tokens = tokenizer.tokenize(line)
        if tokens:
          all_documents[-1].append(tokens)

  # 移除空文档并打乱顺序
  all_documents = [x for x in all_documents if x]
  rng.shuffle(all_documents)

  vocab_words = list(tokenizer.vocab.keys())
  instances = []
  # 多次重复数据,每次都会对数据进行不同的mask,充分利用训练数据
  for _ in range(dupe_factor):
    for document_index in range(len(all_documents)):
      instances.extend(
          # 根据单个文档生成训练样本
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng))

  rng.shuffle(instances)    # 打乱样本顺序
  return instances

create_instances_from_document

这个函数会根据单个文档生成训练样本,具体处理逻辑有两个:一个是生成当前生成样本的目标长度,大部分情况下都会将训练样本填充至允许的最大长度,这样可以节省算力,但是这样预训练跟微调截断的文本长度就有明显的差异,为了减少这样的差距,在10%的情况下,只会生成一个更短的样本,其长度是2到最大长度之间的一个随机值;另一个是“下一句预测”任务中,第二句的选择与拼接。 伪代码逻辑如下:

> 1. 获取当前生成样本的目标长度:
> 2.      90%的情况下取允许的最大长度,10%的情况下随机选择2到最大长度之间的一个值。
> 3. 
> 4. 初始化输出样本集合与当前token集合current_chunk。
> 5. 循环当前文档中所有句子的索引值:
> 6.      将当前索引的句子token添加到current_chunk中。
> 7.      如果达到最后一句或者current_chunk中token的长度大于等于目标长度:
> 8.          随机选择current_chunk中的前几个句子作为样本的“”“A”句。
> 9.          如果current_chunk中只有一句或者50%的概率下:
> 10.             根据样本目标长度与“A”句长度,获取“B”句目标长度。
> 11.             从总文档集合中随机抽取另外一篇文档。
> 12.             从随机文档中随机选择一个句子作为起点,以“B”句目标长度为限,获取样本的“B”句。
> 13.             将当前索引退回到current_chunk中没有使用的句子索引上。
> 11.         否则:
> 12.             将current_chunk剩下的句子作为样本中的“B”句。
> 13.     将句子对截断到最大限制长度。
> 14.     添加开头的“[CLS]”,第一句结尾的“[SEP]”和第二句结尾的“[SEP]”字符。
> 15.     ## 下面小节会详细讲解token遮蔽的方法
> 16.     根据遮蔽语言模型,随机遮蔽token,生成预测目标。
> 17.     获取当个样本实例,并添加到输出样本集合中。
> 18.     置空current_chunk。
> 19. 返回样本集合。

def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
  """
  根据单个文档生成训练样本
  :param all_documents: 所有的文档list,用来随机抽取下一句使用的
  :param document_index: 当前文档的索引,int
  :param max_seq_length: 允许的最大序列长度,int
  :param short_seq_prob: 生成短句子的概率,默认10%,float
  :param masked_lm_prob: 被遮蔽的概率,默认15%, float
  :param max_predictions_per_seq: 允许最多被遮蔽的token数目,int
  :param vocab_words: 词表, list
  :param rng: 随机数生成器, Random
  :return: 单个训练样本,TrainingInstance
  """
  document = all_documents[document_index]

  #最大长度要给 [CLS](开头), [SEP](第一句结尾), [SEP](结尾)留三个位置
  max_num_tokens = max_seq_length - 3

  # 大多数情况下,我们都会填满整个句子,直到“max_seq_length”长度,因为短文本太浪费算力了。
  # 但是,我们有时(short_seq_prob=10%的情况)需要用更短的句子来减少预训练与微调阶段之间的差距。
  # “target_seq_length”只是一个粗略的长度目标,而“max_seq_length”是固定的限制。
  target_seq_length = max_num_tokens
  if rng.random() < short_seq_prob:  # short_seq_prob=10%的情况下,句子长度会随机取一个小于最大长度的值
    target_seq_length = rng.randint(2, max_num_tokens)

  # 我们不会仅仅将文档中的tokens拼接到一个长句中,然后任意地选择一个切分点,因为这样会让“下一句预测”任
  # 务太简单。实际上,我们会根据用户输入,根据实际的句子将输入分成“A”句跟“B”句。
  instances = []
  current_chunk = []
  current_length = 0
  i = 0
  while i < len(document):
    segment = document[i]
    current_chunk.append(segment)
    current_length += len(segment)
    if i == len(document) - 1 or current_length >= target_seq_length:
      if current_chunk:
        a_end = 1   # current_chunk中句子放到“A”句(第一句)中的数量
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)

        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])

        tokens_b = []
        # 下一句是否随机选择
        is_random_next = False
        # current_chunk中只有一句,或者50%的概率,下一句进行随机选择
        if len(current_chunk) == 1 or rng.random() < 0.5:
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)    # “B”句的最大长度限制

          # 对大型语料来说,很少需要一次以上的迭代,但是为了小心起见,我们尽量保证
          # 随机的文档不是当前正在处理的文档。
          for _ in range(10):
            random_document_index = rng.randint(0, len(all_documents) - 1)
            if random_document_index != document_index:
              break

          random_document = all_documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)    # 随机选择一句作为"B"句的开始
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
              break
          # 随机选择下一句的时候,current_chunk中有一些句子是没有用到过的,
          # 为了不浪费语料,将这些句子“放回去”
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        else:    # 真的下一句
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)    # 将句子对截断到最大限制长度

        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1

        tokens = []    # 原始句子对token
        segment_ids = []    # 判断属于第一句还是第二句
        tokens.append("[CLS]")    # 添加开头token
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)

        tokens.append("[SEP]")     # 添加第一句结尾token
        segment_ids.append(0)

        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        tokens.append("[SEP]")    # 添加结尾token
        segment_ids.append(1)

        # 根据遮蔽语言模型,生成预测目标
        (tokens, masked_lm_positions,
         masked_lm_labels) = create_masked_lm_predictions(
             tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
        # 单个训练实例(句子对)
        instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)
        instances.append(instance)
      current_chunk = []
      current_length = 0
    i += 1

  return instances


def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
  """
  将句子对截断到最大限制长度
  :param tokens_a: “A”句token, list
  :param tokens_b: “B”句token,list
  :param max_num_tokens: 允许最大token数目, int
  :param rng: 随机数生成器,Random
  :return: 没有返回,原地修改
  """
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_num_tokens:
      break

    # 截断长度较长的一句
    trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
    assert len(trunc_tokens) >= 1

    # 为了避免偏差,50%的情况下我们丢弃开头的token,50%的情况下丢弃结尾的token
    if rng.random() < 0.5:
      del trunc_tokens[0]
    else:
      trunc_tokens.pop()

create_masked_lm_predictions

这个函数会根据遮蔽语言模型,随机遮蔽token,并生成预测目标。其中会涉及到全词遮蔽的思想,具体指:某个词片段被遮蔽了,则将此词片段所属的整个词都进行遮蔽。实际过程中,在token化的时候,将非开头的词片段加上“##”前缀,这样可以确定词边界来进行全词遮蔽。具体伪代码逻辑如下:

> 1. 初始化候选索引集合。
> 2. 循环遍历句子中每一个token:
> 3.      如果是“[CLS]”或“[SEP]”则跳过。
> 4.      如果要进行全词遮蔽且当前token有前缀“##”:
> 5.          则将当前token的索引添加到上一个候选集中。
> 6.      否则:
> 7.          以当前索引新建一个候选集,并添加到整体候选集合里。
> 8.  随机打乱候选索引集合。
> 9.  
> 10. 获取随机的被遮蔽数量。
> 11. 初始化被遮蔽后的token序列以及被遮蔽的token序列。
> 12. 循环遍历所有的候选集合:
> 13.     如果遮蔽的数量足够,则跳出。
> 14.     如果加上当前候选集合中token的数量,超过限制,则忽略当前集合。
> 15.     如果集合中存在已经被遮蔽过的token,则忽略当前集合。
> 16.     循环遍历当前集合中每一个索引:
> 17.         当前索引的token以80%的概率变成“[MASK]”,10%的概率变为一个随机token,10%的概率保持不变。
> 18.         保存当前被遮蔽的索引以及真实token值。
> 19. 按顺序获取被遮蔽的token位置及真实值。
> 20. 返回被遮蔽后的token序列、被遮蔽的位置以及真实值。

def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
  """
  根据遮蔽语言模型,生成预测目标
  :param tokens: 原始句子对token, list
  :param masked_lm_prob: 被遮蔽的概率,默认15%, float
  :param max_predictions_per_seq: 允许最多被遮蔽的token数目,int
  :param vocab_words: 词表,list
  :param rng: 随机数生成器,Random
  :return: 被遮蔽后的句子对output_tokens,list
           被遮蔽的位置masked_lm_positions, list
           被遮蔽的真实tokenmasked_lm_labels, list
  """

  cand_indexes = []
  for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      continue
    # WWM全词遮蔽表示我们会遮蔽跟原词相关的所有词片段。当词被分层词片段,第一个token
    # 没有任何标记,后面的每一个token都会加上“##”前缀。所以一旦看见带##的token,我
    # 们会将它添加到前面一个词索引的集合里。

    # 要注意到,全词遮蔽并没有改变训练代码,我们依然需要独立地预测每一个词片段,
    # 需要计算整个词表的softmax。
    if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
        token.startswith("##")):
      cand_indexes[-1].append(i)
    else:
      cand_indexes.append([i])

  rng.shuffle(cand_indexes)    # 打乱候选索引集

  output_tokens = list(tokens)    # 被遮蔽后的句子对token

  # 被遮蔽数量
  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))

  masked_lms = []
  covered_indexes = set()
  for index_set in cand_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    # 如果添加当前候选集后,遮蔽的数量超过了最大限制,则跳过此候选集
    if len(masked_lms) + len(index_set) > num_to_predict:
      continue
    is_any_index_covered = False    # 判断token是否出现错
    for index in index_set:
      if index in covered_indexes:
        is_any_index_covered = True
        break
    if is_any_index_covered:
      continue    # 如果token出现过,则跳过,为了尽量遮蔽不同的token
    for index in index_set:
      covered_indexes.add(index)

      masked_token = None
      # 80%会将token换成[MASK]
      if rng.random() < 0.8:
        masked_token = "[MASK]"
      else:
        # 10%会保持原词
        if rng.random() < 0.5:
          masked_token = tokens[index]
        # 10%会替换成一个随机的词
        else:
          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

      output_tokens[index] = masked_token    # 将遮蔽位置的词替换掉

      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
  assert len(masked_lms) <= num_to_predict
  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  masked_lm_positions = []    # 被遮蔽的位置
  masked_lm_labels = []    # 被遮蔽词的真实token
  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)

  return (output_tokens, masked_lm_positions, masked_lm_labels)

训练样本写入文件

上面一小节已经得到了训练样本,但格式上还不能直接作为模型的 输入,本小节会将已有样本改成模型可以直接输入的格式,并写入到文件中。为了保持所有训练数据的一致性,会将每句的token id用0补全到最大允许长度,新增input_mask的二值列表,用来判断哪些是增补的,同样的被遮蔽序列也会用0补全到最大允许被遮蔽数目,新增masked_lm_weights的二值列表,来判断哪些是增补的。将补齐后的数据转换成TF Example格式(可参考NLP实战篇之tf2数据输入),然后写入文件中。

def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files):
  """
  根据训练样本,写入TF样本文件
  :param instances: 训练样本list, [TrainingInstance]
  :param tokenizer: token切分类,FullTokenizer
  :param max_seq_length: 允许最大序列长度,int
  :param max_predictions_per_seq: 允许最大遮蔽数目,int
  :param output_files: TF exmaple输出文件路径, list
  :return:
  """
  writers = []
  for output_file in output_files:    # 文件会很大,所以会选择写入多个文件
    writers.append(tf.python_io.TFRecordWriter(output_file))

  writer_index = 0

  total_written = 0
  for (inst_index, instance) in enumerate(instances):
    input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)    # token转为id
    input_mask = [1] * len(input_ids)    # 有实际输入的位置为1,没有实际输入的后续用0pad
    segment_ids = list(instance.segment_ids)
    assert len(input_ids) <= max_seq_length

    # 将输入序列长度padding到“max_seq_length”,用0padding
    while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    masked_lm_positions = list(instance.masked_lm_positions)    # 被遮蔽的位置
    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)    # 被遮蔽token的id
    masked_lm_weights = [1.0] * len(masked_lm_ids)    # 实际遮蔽的权重为1,没有实际遮蔽的权重后续用0pad

    # 将遮蔽序列长度padding到“max_predictions_per_seq”,用0padding
    while len(masked_lm_positions) < max_predictions_per_seq:
      masked_lm_positions.append(0)
      masked_lm_ids.append(0)
      masked_lm_weights.append(0.0)

    next_sentence_label = 1 if instance.is_random_next else 0    # 下一句预测的标签

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(input_ids)    # 输入的id, [max_seq_length]
    features["input_mask"] = create_int_feature(input_mask)    # 输入的mask, [max_seq_length]
    features["segment_ids"] = create_int_feature(segment_ids)    # 第一句、第二句, [max_seq_length]
    features["masked_lm_positions"] = create_int_feature(masked_lm_positions)    # 语言模型中被遮蔽的位置, [max_predictions_per_seq]
    features["masked_lm_ids"] = create_int_feature(masked_lm_ids)    # 遮蔽语言模型的标签, [max_predictions_per_seq]
    features["masked_lm_weights"] = create_float_feature(masked_lm_weights)    # 遮蔽语言模型中被遮蔽的标签的权重, [max_predictions_per_seq]
    features["next_sentence_labels"] = create_int_feature([next_sentence_label])   # 下一句预测的标签, [1]

    # 转为TF example
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))

    writers[writer_index].write(tf_example.SerializeToString())    # 写入文件
    writer_index = (writer_index + 1) % len(writers)

    total_written += 1

    if inst_index < 20:    # 展示前20个样本
      tf.logging.info("*** Example ***")
      tf.logging.info("tokens: %s" % " ".join(
          [tokenization.printable_text(x) for x in instance.tokens]))

      for feature_name in features.keys():
        feature = features[feature_name]
        values = []
        if feature.int64_list.value:
          values = feature.int64_list.value
        elif feature.float_list.value:
          values = feature.float_list.value
        tf.logging.info(
            "%s: %s" % (feature_name, " ".join([str(x) for x in values])))

  for writer in writers:
    writer.close()

  tf.logging.info("Wrote %d total instances", total_written)


def create_int_feature(values):
  # int类型数据转换
  feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
  return feature


def create_float_feature(values):
  # float类型数据转换
  feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
  return feature
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值