一、背景
利用tensorflow训练深度模型,数据的处理和输入是必须的一个步骤,为了高效的读取数据,用官网文档的话说就是tfrecord对数据进行序列化并将其存储在一组可线性读取的文件中,从硬盘流式读取。对于内存的利用有很大的作用。
本文将从推荐模型通常使用的是结构化的数据输入,一步步拆解做演示和解释。
也可以看直接看知乎 [数据读取]1.tfRecord的生成和读取 - 知乎
二、TFRecord生成
- 数据说明
工程上一般数据输入是从hive表拉取的文件列表,整理成数据文件列表形式,以文件列表输入,ifrecord文件列表输出。一般要点包括
(1)多线程读取解析。
(2)tfrecord写入函数
(3)tfrecord行字段解析 - 示例代码演示
from multiprocessing import Pool
import os
import glob
import numpy as np
import pandas as pd
//(3)tfrecord行字段解析
def parse_line(line):
feature_list = line.strip('\n').split('\t')
features = dict()
feature.update({'fea1':tf.train.Feature(float_list=tf.train.FloatList(value=np.array(feature_list[0],dtype=np.float)))})
feature.update({'fea2':tf.train.Feature(bytes_list=tf.train.BytesList(value=feature_list[1]))})
feature.update({'label':tf.train.Feature(int64_list=tf.train.Int64List(value=feature_list[2]))})
return features
//(2)tfrecord写入函数
def tfrecord_process(params):
infile_name = params[0]
output_dir = params[1]
outfile_name = f'{output_dir}.tfrecord'
tfrecord_out = tf.io.TFRecordWriter(outfile_name) //定义输出文件
with open(infile_name, 'r', encoding='utf-8') as fp:
for line in fp.readlines():
feat = parse_line(line) //(3)tfrecord行字段解析
example = tf.train.Example(features=tf.train.Features(feature=feat))
serialized = example.SerializeToString()
tfrecord_out.write(serialized)
tfrecord_out.close()
//(1)多线程读取解析。
if __name_ == "__main__":
//定义多线程读取
input_dir = args.input_dir
output_dir = args.output_dir
input_files = glob.glob(input_dir)
parameters = [[item, output_dir] for item in input_files]
pool = Pool(args.threads)
pool.map(tfrecord_process, parameters)
pool.close()
pool.join()
函数解释
example = tf.train.Example(features=tf.train.Features(feature=feat))
参数:features(字典key-value形式,key是特征名字,value是特征值),value=tf.train.Feature()
tf.train.feature属性,每一个feature 是一个key-value的键值对,其中,key 是string类型,value 的取值有三种:
* bytes_list: 可以存储string 和byte两种数据类型。
* float_list: 可以存储float(float32)与double(float64) 两种数据类型 。
* int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64。
三、TFRecord读取
- 读取说明
模型训练直接获取tfrecord目录进行读取即可,假如模型训练使用tf.estimator的形式,读取就作为tf.estimator.TrainSpec的参数input_fn进行构造传递即可。 - input_fn代码构造示例
//解析特征输入
def parse_exmp(serial_exmp):
y = fc.numeric_column('label', default_value=0, dtype=tf.int64)
fea_cols= [y]
fea_cols+= gl_feature_columns
feature_spec = tf.feature_column.make_parse_example_spec(fea_cols)
x = tf.parse_single_example(serial_exmp, features=feature_spec)
y = x.pop('label')
return x, y
//构造输入函数
def input_fn(filenames, batch_size,shuffle,epochs=1):
files = tf.data.Dataset.list_files(filenames, shuffle=shuffle, seed=42)
line_pre = files.apply(tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=20, block_length=20, sloppy=True)) \
.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=batch_size * 100, count=epochs)) \
.apply(tf.contrib.data.map_and_batch(map_func=lambda x: parse_exmp(x), batch_size=batch_size)) \
.prefetch(batch_size * 2)
return line_pre
def main():
batch_size = 128
shuffle = False
input_fn_train = lambda: input_fn(train_files, batch_size, shuffle, epochs)
train_spec = tf.estimator.TrainSpec(
input_fn=input_fn_train,
)
四、总结
暂时总结tfrecord的生成和输入,完成对模型数据这方面的代码构造,后面继续研究其他构造输入的方法或者更优方式,大家也多提意见,再做总结修改。
下一步计划:对tf.keras的数据输入的读取方面的技术总结。