Training the "Attention Is All You Need" model with tensor2tensor
Hyperparameter definition:
create_experiment calls train_utils.create_hparams, which calls problem_hparams.problem_hparams and then the transformer hparams function; transformer obtains the basic model hyperparameters from common_hparams and supplements them with Transformer-specific ones.
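A minimal sketch of this layering pattern (the parameter names and values below are illustrative, not the exact set defined in common_hparams.py and transformer.py):

import tensorflow as tf

def basic_params1():
  # Common defaults shared by all models (cf. common_hparams.py).
  return tf.contrib.training.HParams(
      batch_size=4096,
      hidden_size=512,
      learning_rate=0.1,
      dropout=0.2)

def transformer_base():
  # Start from the common defaults, then override and extend them
  # with Transformer-specific hyperparameters.
  hparams = basic_params1()
  hparams.hidden_size = 512
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("attention_dropout", 0.0)
  return hparams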
Model invocation:
trainer_utils._cond_on_index calls fn(cur_idx), where fn is nth_model defined inside trainer_utils.model_fn; nth_model goes through t2t_model._with_timing.fn_with_timing, whose wrapped fn ultimately calls transformer.model_fn_body to produce the logits and loss.
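The dispatch is essentially a chain of tf.cond branches that picks the model function matching the runtime problem index; a paraphrased sketch, not the verbatim t2t source:

import tensorflow as tf

def _cond_on_index(fn, index_tensor, cur_idx, max_idx):
  # Build nested tf.cond ops so that fn(i) is executed for the branch
  # whose static index i matches the runtime value of index_tensor.
  if cur_idx == max_idx:
    return fn(cur_idx)
  return tf.cond(
      tf.equal(index_tensor, cur_idx),
      lambda: fn(cur_idx),
      lambda: _cond_on_index(fn, index_tensor, cur_idx + 1, max_idx))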
Training data generation:
The entry point is tensor2tensor/bin/t2t-datagen.py.
Each task and dataset has its own data-generation function; all of them are registered in the dictionary _SUPPORTED_PROBLEM_GENERATORS.
For example, to train the en-fr "Attention Is All You Need" model, the data-generation entry is defined as:
"wmt_enfr_tokens_32k": (
    lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
    lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15))
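t2t-datagen looks the requested problem up in this dictionary and writes the yielded examples out as training files; roughly as below (a simplified sketch with abbreviated argument lists, since the real script also shards and shuffles the output):

# Each entry is a pair of zero-argument callables: one for the training
# split, one for the dev split. Calling them returns example generators.
training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[FLAGS.problem]
generator_utils.generate_files(
    training_gen(), FLAGS.problem + "-train", FLAGS.data_dir)
generator_utils.generate_files(
    dev_gen(), FLAGS.problem + "-dev", FLAGS.data_dir)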
enfr_wordpiece_token_generator is defined in wmt.py; the code is as follows:
def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
  """Instance of token generator for the WMT en->fr task."""
  symbolizer_vocab = generator_utils.get_or_generate_vocab(
      tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size)
  datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
  tag = "train" if train else "dev"
  data_path = _compile_data(tmp_dir, datasets, "wmt_enfr_tok_%s" % tag)
  return token_generator(data_path + ".lang1", data_path + ".lang2",
                         symbolizer_vocab, 1)
First, generator_utils.get_or_generate_vocab builds (or loads a previously generated) wordpiece vocabulary; _ENFR_TRAIN_DATASETS lists the raw input data files; token_generator then reads the training data and converts it into integer token ids:
import tensorflow as tf

def token_generator(source_path, target_path, token_vocab, eos=None):
  eos = token_vocab['</S>']  # always terminate with </S>, overriding the eos argument
  eos_list = [] if eos is None else [eos]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      while source and target:
        # Convert the text to token ids and append eos.
        source_ints = word_num(source.strip('\n'), token_vocab) + eos_list
        target_ints = word_num(target.strip('\n'), token_vocab) + eos_list
        # Keep only pairs whose longer side has between 5 and 20 tokens.
        slen = max(len(source_ints), len(target_ints))
        if 5 <= slen <= 20:
          yield {"inputs": source_ints, "targets": target_ints}
        source, target = source_file.readline(), target_file.readline()
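word_num is not shown in the post; a plausible minimal implementation, assuming token_vocab is a dict mapping token strings to integer ids (the unk_id parameter and its value are assumptions, not part of the original code):

def word_num(sentence, token_vocab, unk_id=3):
  # Hypothetical helper: map a whitespace-tokenized sentence to
  # vocabulary ids, falling back to unk_id for out-of-vocabulary tokens.
  return [token_vocab.get(tok, unk_id) for tok in sentence.split()]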