数据准备
Trax中没有写好的数据数据预处理脚本,所以要自己写数据预处理的过程,这里我就直接使用tensorflow_official_nlp_transformer使用示例中生成TTRecoard数据
# 获取训练语料
batch_size = 8
max_length = 100
static_batch = True
model_dir = './data_dir/trax_nlp/train_dir/'
_READ_RECORD_BUFFER = 8*1000*1000
def _load_records(filename):
"""Read file and return a dataset of tf.Examples."""
return tf.data.TFRecordDataset(filename, buffer_size=_READ_RECORD_BUFFER)
def _parse_example(serialized_example):
"""Return inputs and targets Tensors from a serialized tf.Example."""
data_fields = {
"inputs": tf.io.VarLenFeature(tf