在tensorflow/nmt项目中,训练数据和推断数据的输入使用了新的Dataset API,应该是tensorflow 1.2之后引入的API,方便数据的操作。如果你还在使用老的Queue和Coordinator的方式,建议升级高版本的tensorflow并且使用Dataset API。
本教程将从训练数据和推断数据两个方面,详解解析数据的具体处理过程,你将看到文本数据如何转化为模型所需要的实数,以及中间的张量的维度是怎么样的,batch_size和其他超参数又是如何作用的。
训练数据的处理
先来看看训练数据的处理。训练数据的处理比推断数据的处理稍微复杂一些,弄懂了训练数据的处理过程,就可以很轻松地理解推断数据的处理。
训练数据的处理代码位于nmt/utils/iterator_utils.py文件内的get_iterator函数。
函数的参数
我们先来看看这个函数所需要的参数是什么意思:
参数解释
src_dataset
源数据集
tgt_dataset
目标数据集
src_vocab_table
源数据单词查找表,就是个单词和int类型数据的对应表
tgt_vocab_table
目标数据单词查找表,就是个单词和int类型数据的对应表
batch_size
批大小
sos
句子开始标记
eos
句子结尾标记
random_seed
随机种子,用来打乱数据集的
num_buckets
桶数量
src_max_len
源数据最大长度