[数据读取]1.tfRecord的生成和读取

一、背景

利用tensorflow训练深度模型,数据的处理和输入是必须的一个步骤,为了高效的读取数据,用官网文档的话说就是tfrecord对数据进行序列化并将其存储在一组可线性读取的文件中,从硬盘流式读取。对于内存的利用有很大的作用。
本文将从推荐模型通常使用的是结构化的数据输入,一步步拆解做演示和解释。
也可以看直接看知乎 [数据读取]1.tfRecord的生成和读取 - 知乎

二、TFRecord生成

  1. 数据说明
    工程上一般数据输入是从hive表拉取的文件列表,整理成数据文件列表形式,以文件列表输入,ifrecord文件列表输出。一般要点包括
    (1)多线程读取解析。
    (2)tfrecord写入函数
    (3)tfrecord行字段解析
  2. 示例代码演示
from multiprocessing import Pool
import os
import glob
import numpy as np
import pandas as pd

//(3)tfrecord行字段解析	
def parse_line(line):
    feature_list = line.strip('\n').split('\t')
    features = dict()
    feature.update({'fea1':tf.train.Feature(float_list=tf.train.FloatList(value=np.array(feature_list[0],dtype=np.float)))})
    feature.update({'fea2':tf.train.Feature(bytes_list=tf.train.BytesList(value=feature_list[1]))})

    feature.update({'label':tf.train.Feature(int64_list=tf.train.Int64List(value=feature_list[2]))})
    return features

//(2)tfrecord写入函数
def tfrecord_process(params):
    infile_name = params[0]
    output_dir = params[1]
	
    outfile_name = f'{output_dir}.tfrecord'
    tfrecord_out = tf.io.TFRecordWriter(outfile_name) //定义输出文件
    with open(infile_name, 'r', encoding='utf-8') as fp:
        for line in fp.readlines():
            feat = parse_line(line) //(3)tfrecord行字段解析
            example = tf.train.Example(features=tf.train.Features(feature=feat))
            serialized = example.SerializeToString()
            tfrecord_out.write(serialized)
    tfrecord_out.close()

//(1)多线程读取解析。
if __name_ == "__main__":
    //定义多线程读取
    input_dir = args.input_dir
    output_dir = args.output_dir
    input_files = glob.glob(input_dir)
    parameters = [[item, output_dir] for item in input_files]

    pool = Pool(args.threads)
    pool.map(tfrecord_process, parameters)
    pool.close()
    pool.join()

函数解释
example = tf.train.Example(features=tf.train.Features(feature=feat))
参数:features(字典key-value形式,key是特征名字,value是特征值),value=tf.train.Feature()
tf.train.feature属性,每一个feature 是一个key-value的键值对,其中,key 是string类型,value 的取值有三种:
* bytes_list: 可以存储string 和byte两种数据类型。
* float_list: 可以存储float(float32)与double(float64) 两种数据类型 。
* int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64。

三、TFRecord读取

  1. 读取说明
    模型训练直接获取tfrecord目录进行读取即可,假如模型训练使用tf.estimator的形式,读取就作为tf.estimator.TrainSpec的参数input_fn进行构造传递即可。
  2. input_fn代码构造示例
//解析特征输入
def parse_exmp(serial_exmp):
    y = fc.numeric_column('label', default_value=0, dtype=tf.int64)
    fea_cols= [y]
    fea_cols+= gl_feature_columns
    feature_spec = tf.feature_column.make_parse_example_spec(fea_cols)
    x = tf.parse_single_example(serial_exmp, features=feature_spec)
    y = x.pop('label')
    return x, y

//构造输入函数
def input_fn(filenames, batch_size,shuffle,epochs=1):
    files = tf.data.Dataset.list_files(filenames, shuffle=shuffle, seed=42)
    line_pre = files.apply(tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=20, block_length=20, sloppy=True)) \
        .apply(tf.contrib.data.shuffle_and_repeat(buffer_size=batch_size * 100, count=epochs)) \
        .apply(tf.contrib.data.map_and_batch(map_func=lambda x: parse_exmp(x), batch_size=batch_size)) \
        .prefetch(batch_size * 2)
    return line_pre

def main():   
    batch_size = 128
    shuffle = False
    input_fn_train = lambda: input_fn(train_files, batch_size, shuffle, epochs) 
    train_spec = tf.estimator.TrainSpec(
        input_fn=input_fn_train,
    )

四、总结

暂时总结tfrecord的生成和输入,完成对模型数据这方面的代码构造,后面继续研究其他构造输入的方法或者更优方式,大家也多提意见,再做总结修改。

下一步计划:对tf.keras的数据输入的读取方面的技术总结。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

感冒灵pp

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值