目录
在训练模型的时候,一般会将数据预处理转换成tfrecord格式,负责I/O操作的CPU和进行数值运行计算的GPU相互之间可以并行工作,保证GPU高的利用率。以下是对特征是定长和变长读写tfrecord方式。
1 写tfrecor方式
一般会将数据按照模型训练所需要的方式对输入x和label标签进行tfrecord格式转换。主要有定长和变长两种方式,根据实际应用和需求决定。若输入的每个example的input 是变长的,比如每个example的输入特征索引个数不是相同的,则可以按照变长的方式转换,否则按照定长的方式转换。
1.1 变长特征转tfrecord
import collections
writer = tf.python_io.TFRecordWriter('data.tfrecord')
def toTF(data):
'''
data是一个dict,假设其中key有input_x和input_y,
对应的value是索引list
'''
features = collections.OrderedDict()
input_x = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x"])))
features["input_x"] = tf.train.FeatureList(feature=input_x)
input_y = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y"])))
features["input_y"] = tf.train.FeatureList(feature=input_y)
sequence_example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list=features))
writer.write(sequence_example.SerializeToString())
以下方式实现与上面方式等价:
def toTF_v2(data)
sequence_example = tf.train.SequenceExample()
input_x = sequence_example.feature_lists.feature_list["input_x"]
input_y = sequence_example.feature_lists.feature_list["input_y"]
for x in data["input_x"]:
input_x.feature.add().int64_list.value.append(x)
for y in data["input_y"]:
input_y.feature.add().int64_list.value.append(y)
writer.write(sequence_example.SerializeToString())
1.2 定长特征转tfrecord
def toTF_fixed(data):
features = collections.OrderedDict()
features["input_x"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x")))
features["input_y"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y")))
example = tf.train.Example(features=tf.train.Features(feature=features))
write.write(example.SerializeToString())
2 读tfrecord
和写trrecord一样,也分定长和变长方式,如果写tfrecord是定长方式,则读tfrecord也需要定长方式。读写方式需要保持一致。
2.1 变长方式读tfrecord
需要定义特征的格式,如果是变长则定义tf.FixedLenSequenceFeature类型特征
import tensorflow as tf
features = {
'input_x': tf.FixedLenSequenceFeature([], tf.int64)
'input_y': tf.FixedLenSequenceFeature([], tf.int64)
}
2.2 定长方式读tfrecord
定长方式用tf.FixedLenFeature类型
seq_length = 10
features = {
'input_x': tf.FixedLenFeature([seq_length], tf.int64).
'input_y': tf.FixedLenFeature([seq_length], tf.int64
}
3 从hdfs中读取批量tfrecord文件
当训练数据量级很大时,一般转tfrecord试用分布式方式处理数据,提高效率。训练模型的时候,可以从远程,例如hdfs上读取批量文件。以下是从hdfs上批量读取tfrecord文件。
def input_fn_builder(file_path, num_cpu_threads, seq_length, num_class, batch_size):
'''
其中file_path是hdfs上文件的路径,比如data目录下的所有tfrecord文件
读的是定长的feature
'''
features = {
'input_x': tf.FixedLenFeature([seq_length], tf.int64),
'input_y': tf.FixedLenFeature([seq_length], tf.int64),
}
def _decode_record(record):
# 一个样本解析
example = tf.io.parse_single_example(record, features)
multi_label_enc = tf.one_hot(indices=example["input_y"], depth=num_class)
example["input_y"] = tf.reduce_sum(multi_label_enc, axis=0)
return example
def _decode_batch_record(batch_record):
# 一个batch样本解析
batch_example = tf.io.parse_example(serialized=batch_record, features=features)
multi_label_enc = tf.one_hot(indices=batch_example["input_y"], depth=num_class)
batch_example["input_y"] = tf.reduce_sum(multi_label_enc, axis=1)
return batch_example
def input_fn(params):
# d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
d = tf.data.Dataset.list_files(file_path)
d = d.repeat()
d = d.shuffle(buffer_size=100)
d = d.appley(
tf.contrib.data.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=True,
cycle_length=num_cpu_threads))
d = d.apply(
tf.contrib.data.map_and_batch(
lambda record: _decode_record(record),
batch_size = batch_size,
num_parallel_batches=num_cpu_threads,
drop_remainder=True))
return d
def input_fn_v2(params):
d = tf.data.Dataset.list_files(file_path)
d = d.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=num_cpu_threads, block_length=128).\
batch(batch_size).map(_decode_batch_record, num_parallel_calls=tf.data.experimental.AUTOTRUE).prefetch(
tf.data.experimental.AUTOTUNE).repeat()
return d
return input_fn
#return input_fn_v2
上面提供了两个解析函数,input_fn和input_fn_v2两种方式都可行,配合estimator方式训练,可以使得CPU读取数据与GPU训练数据之间可以并行处理,减少等待时间,提高GPU的利用率,加快训练速度。解析tfrecord文件时,有下面四种方式,根据自己具体的数据格式进行选择:
- 解析单个样本,定长特征:tf.io.parse_single_example()
- 解析单个样本,变长特征:tf.io.parse_single_sequence_example()
- 解析批量样本,定长特征:tf.io.parse_example()
- 解析批量样本,定长特征:tf.io.parse_sequence_example()