tensor2tensor

Using tensor2tensor to train the "Attention Is All You Need" model

Parameter definition:

create_experiment calls trainer_utils.create_hparams, which calls problem_hparams.problem_hparams and then transformer. transformer calls common_hparams to obtain the basic model parameters and then adds its own parameters on top of them.
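The layering described above (a shared set of basic parameters that the model then extends) can be sketched in plain Python. This is only an illustration of the pattern: the function names `basic_params` and `transformer_params`, and every value shown, are hypothetical and are not taken from tensor2tensor.

```python
def basic_params():
    """Hypothetical stand-in for the basic parameters shared by all models."""
    return {"batch_size": 1024, "hidden_size": 64, "dropout": 0.2}


def transformer_params():
    """Hypothetical stand-in for the model-specific parameter set."""
    hparams = dict(basic_params())  # start from a copy of the basic parameters
    hparams["hidden_size"] = 512    # override a basic value
    hparams["num_heads"] = 8        # add a parameter only this model needs
    return hparams


hparams = transformer_params()
print(hparams)
```

Copying the basic dictionary before modifying it keeps the shared defaults untouched, so another model could start from the same base.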

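Model invocation:

trainer_utils._cond_on_index calls fn(cur_idx), where fn is trainer_utils.model_fn.nth_model. That function calls t2t_model._with_timing.fn_with_timing, which calls its wrapped fn, and that fn calls transformer.model_fn_body to obtain the loss and logits.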
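The name `_with_timing.fn_with_timing` suggests a wrapper that times the call to the wrapped function and passes its result through. A minimal sketch of that idea follows, assuming a plain `print` for logging; `with_timing` and `toy_model_fn_body` are hypothetical names used only for illustration and are not the library's code.

```python
import time


def with_timing(fn, msg):
    """Returns a wrapper that reports how long each call to fn takes."""
    def fn_with_timing(*args, **kwargs):
        start = time.time()
        result = fn(*args, **kwargs)  # run the wrapped function unchanged
        print("Doing %s took %.3f sec." % (msg, time.time() - start))
        return result
    return fn_with_timing


def toy_model_fn_body(x):
    """Hypothetical stand-in for a model body returning (loss, logits)."""
    return x * 2.0, [x, x]


timed_body = with_timing(toy_model_fn_body, "model body")
loss, logits = timed_body(1.5)
```

The wrapper does not change what the wrapped function returns, so callers can treat `timed_body` exactly like the original function.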

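Training data generation:

The entry point is tensor2tensor/bin/t2t-datagen.py.

Each task and dataset has its own data generation function. All of these functions are registered in the dictionary _SUPPORTED_PROBLEM_GENERATORS.

For example, to train the en-fr "Attention Is All You Need" model, the data processing functions are defined as: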

"wmt_enfr_tokens_32k": (
    lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
    lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)),
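Each entry pairs a training generator with a dev generator, and wrapping them in `lambda` delays any work until the chosen problem is actually run. The sketch below shows that lookup pattern with a toy generator; `SUPPORTED_GENERATORS`, `toy_token_generator`, and the key `toy_tokens_32k` are hypothetical names, not part of tensor2tensor.

```python
def toy_token_generator(train, vocab_size):
    """Hypothetical generator that yields two placeholder examples."""
    tag = "train" if train else "dev"
    for i in range(2):
        yield {"split": tag, "vocab_size": vocab_size, "index": i}


# Maps a problem name to (training generator factory, dev generator factory).
SUPPORTED_GENERATORS = {
    "toy_tokens_32k": (
        lambda: toy_token_generator(True, 2**15),
        lambda: toy_token_generator(False, 2**15)),
}

# Nothing is generated until the selected factories are called.
train_gen_fn, dev_gen_fn = SUPPORTED_GENERATORS["toy_tokens_32k"]
train_examples = list(train_gen_fn())
dev_examples = list(dev_gen_fn())
```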

enfr_wordpiece_token_generator is defined in wmt. Its code is:
def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
  """Instance of token generator for the WMT en->fr task."""
  symbolizer_vocab = generator_utils.get_or_generate_vocab(
      tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size)
  datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
  tag = "train" if train else "dev"
  data_path = _compile_data(tmp_dir, datasets, "wmt_enfr_tok_%s" % tag)
  return token_generator(data_path + ".lang1", data_path + ".lang2",
                         symbolizer_vocab, 1)

First, generator_utils.get_or_generate_vocab builds the vocabulary. _ENFR_TRAIN_DATASETS holds the input datasets. The token_generator function then reads the training data and converts it to integer ids:

def token_generator(source_path, target_path, token_vocab, eos=None):
  """Yields {"inputs", "targets"} dicts of token ids from parallel text files."""
  # Note: the eos argument is overridden by the id of '</S>' in the vocabulary.
  eos = token_vocab['</S>']
  eos_list = [] if eos is None else [eos]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      while source and target:
        # Convert each line to token ids and append EOS.
        # word_num is a helper that is not shown in this post.
        source_ints = word_num(source.strip('\n'), token_vocab) + eos_list
        target_ints = word_num(target.strip('\n'), token_vocab) + eos_list
        slen = max(len(source_ints), len(target_ints))
        # Keep only pairs whose longer side has 5 to 20 tokens (EOS included).
        if 5 <= slen <= 20:
          yield {"inputs": source_ints, "targets": target_ints}
        source, target = source_file.readline(), target_file.readline()
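
The helper word_num used in token_generator is not shown in this post. A minimal sketch of what such a helper might look like follows, assuming whitespace tokenization, a plain dict vocabulary, and an `<UNK>` entry for unknown words (all three are assumptions). `filtered_pairs` is a hypothetical in-memory variant of the loop above, included so the length filter can be tried without files.

```python
def word_num(sentence, token_vocab, unk="<UNK>"):
    """Hypothetical helper: maps whitespace-separated words to integer ids."""
    unk_id = token_vocab[unk]
    return [token_vocab.get(word, unk_id) for word in sentence.split()]


def filtered_pairs(pairs, token_vocab, min_len=5, max_len=20):
    """Yields id-encoded pairs whose longer side is within [min_len, max_len]."""
    eos_list = [token_vocab["</S>"]]
    for source, target in pairs:
        source_ints = word_num(source, token_vocab) + eos_list
        target_ints = word_num(target, token_vocab) + eos_list
        longest = max(len(source_ints), len(target_ints))
        if min_len <= longest <= max_len:
            yield {"inputs": source_ints, "targets": target_ints}


vocab = {"<UNK>": 0, "</S>": 1, "a": 2, "b": 3, "c": 4, "d": 5}
pairs = [("a b", "c"), ("a b c d", "a b")]
examples = list(filtered_pairs(pairs, vocab))
```

In this toy run the first pair is dropped because its longer side has only 3 ids including EOS, while the second pair is kept because its source side has exactly 5.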