TensorFlow 2.x: Reading Data from CSV Files and Generating TFRecord Files

This post reads data from a set of CSV files and writes it back out as TFRecord files.

Code example:

import tensorflow as tf
import os 
import pprint
import numpy as np
# Read CSV files and convert them to TFRecord files
source_dir = "./customize_generate_csv/"
print(os.listdir(source_dir))

['test_00.csv', 'test_01.csv', 'test_02.csv', 'test_03.csv', 'test_04.csv', 'test_05.csv', 'test_06.csv', 'test_07.csv', 'test_08.csv', 'test_09.csv', 'train_00.csv', 'train_01.csv', 'train_02.csv', 'train_03.csv', 'train_04.csv', 'train_05.csv', 'train_06.csv', 'train_07.csv', 'train_08.csv', 'train_09.csv', 'train_10.csv', 'train_11.csv', 'train_12.csv', 'train_13.csv', 'train_14.csv', 'train_15.csv', 'train_16.csv', 'train_17.csv', 'train_18.csv', 'train_19.csv', 'valid_00.csv', 'valid_01.csv', 'valid_02.csv', 'valid_03.csv', 'valid_04.csv', 'valid_05.csv', 'valid_06.csv', 'valid_07.csv', 'valid_08.csv', 'valid_09.csv']

# Helper: collect all file paths under source_dir that start with a given prefix
def get_filenames_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results

train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")

pprint.pprint(train_filenames)
pprint.pprint(valid_filenames)
pprint.pprint(test_filenames)

['./customize_generate_csv/train_00.csv',
 './customize_generate_csv/train_01.csv',
 './customize_generate_csv/train_02.csv',
 './customize_generate_csv/train_03.csv',
 './customize_generate_csv/train_04.csv',
 './customize_generate_csv/train_05.csv',
 './customize_generate_csv/train_06.csv',
 './customize_generate_csv/train_07.csv',
 './customize_generate_csv/train_08.csv',
 './customize_generate_csv/train_09.csv',
 './customize_generate_csv/train_10.csv',
 './customize_generate_csv/train_11.csv',
 './customize_generate_csv/train_12.csv',
 './customize_generate_csv/train_13.csv',
 './customize_generate_csv/train_14.csv',
 './customize_generate_csv/train_15.csv',
 './customize_generate_csv/train_16.csv',
 './customize_generate_csv/train_17.csv',
 './customize_generate_csv/train_18.csv',
 './customize_generate_csv/train_19.csv']
['./customize_generate_csv/valid_00.csv',
 './customize_generate_csv/valid_01.csv',
 './customize_generate_csv/valid_02.csv',
 './customize_generate_csv/valid_03.csv',
 './customize_generate_csv/valid_04.csv',
 './customize_generate_csv/valid_05.csv',
 './customize_generate_csv/valid_06.csv',
 './customize_generate_csv/valid_07.csv',
 './customize_generate_csv/valid_08.csv',
 './customize_generate_csv/valid_09.csv']
['./customize_generate_csv/test_00.csv',
 './customize_generate_csv/test_01.csv',
 './customize_generate_csv/test_02.csv',
 './customize_generate_csv/test_03.csv',
 './customize_generate_csv/test_04.csv',
 './customize_generate_csv/test_05.csv',
 './customize_generate_csv/test_06.csv',
 './customize_generate_csv/test_07.csv',
 './customize_generate_csv/test_08.csv',
 './customize_generate_csv/test_09.csv']
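
As an aside, the same lists can be built with the standard-library glob module (an equivalent one-liner, assuming the directory layout above):

import glob
train_filenames = sorted(glob.glob(os.path.join(source_dir, "train_*.csv")))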

# Parse a single CSV line into (features, label)
# n_fields: number of columns per row
def parse_csv_line(line, n_fields = 9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    # tf.stack(): stack the parsed scalar tensors into a single 1-D tensor
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y
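
A quick sanity check on a single line (the sample values below are made up; each row here has 8 features plus 1 label):

sample_line = b"1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,1.5"  # hypothetical row
x, y = parse_csv_line(sample_line, n_fields=9)
print(x.shape, y.shape)  # (8,) (1,)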

# Build a tf.data.Dataset from a list of CSV files
# n_reader: number of files read in parallel
# n_parse_threads: parallelism used when parsing lines
# shuffle_buffer_size: size of the shuffle buffer
def csv_reader_dataset(filenames, n_reader=5, batch_size=32, n_parse_threads=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    # repeat(): with no argument, the dataset repeats indefinitely.
    # The training data is traversed more than once; training is bounded by epochs/steps instead.
    dataset = dataset.repeat()
    # interleave(): read lines from several files at once and weave them into one dataset;
    # skip(1) drops the header row of each file
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length = n_reader
    )
    # shuffle() returns a new dataset, so the result must be reassigned
    dataset = dataset.shuffle(shuffle_buffer_size)
    # map(): parse each line with parse_csv_line (which uses tf.io.decode_csv)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 32
train_set = csv_reader_dataset(train_filenames, batch_size=batch_size)
valid_set = csv_reader_dataset(valid_filenames, batch_size=batch_size)
test_set = csv_reader_dataset(test_filenames, batch_size=batch_size)
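
Taking one batch confirms the shapes before anything is written out:

for x_batch, y_batch in train_set.take(1):
    print(x_batch.shape, y_batch.shape)  # (32, 8) (32, 1)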
# Store the data in TFRecord format

def serialize_example(x, y):
    """Convert x, y into a tf.train.Example and serialize it."""
    # Note: eager tensors must be converted with .numpy() before building a FloatList
    input_features = tf.train.FloatList(value = x.numpy())
    label = tf.train.FloatList(value = y.numpy())
    features = tf.train.Features(
        feature = {
            "input_features": tf.train.Feature(float_list = input_features),
            "label" : tf.train.Feature(float_list = label)
        }
    )
    example = tf.train.Example(features = features)
    return example.SerializeToString()
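
Though not part of the original snippet, the inverse is useful for verifying the output later: serialized examples can be parsed back with tf.io.parse_single_example, using a feature description that mirrors serialize_example (a minimal sketch; the field names and sizes match the code above):

expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32),
}

def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example, expected_features)
    return example["input_features"], example["label"]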

# n_shards: number of shard files to write
# steps_per_shard: number of batches written into each shard
def csv_dataset_to_tfrecords(base_filename, dataset, n_shards, steps_per_shard, compression_type = None):
    options = tf.io.TFRecordOptions(compression_type = compression_type)
    all_filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            # Take steps_per_shard batches from the dataset for this shard
            for x_batch, y_batch in dataset.take(steps_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards
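
For example, with batch_size = 32 and n_shards = 20, the training set yields 11610 // 32 // 20 = 18 batches per shard, i.e. about 18 × 32 = 576 examples in each training shard.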

output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    
train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(valid_basename, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_filenames = csv_dataset_to_tfrecords(test_basename, test_set, n_shards, test_steps_per_shard, None)
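
To use the generated files for training, they can be read back the same way the CSV files were, substituting tf.data.TFRecordDataset for TextLineDataset and parse_example (defined above) for parse_csv_line (a sketch under the same assumptions as the parsing code above):

def tfrecords_reader_dataset(filenames, n_reader=5, batch_size=32, n_parse_threads=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length = n_reader
    )
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

tfrecord_train_set = tfrecords_reader_dataset(train_tfrecord_filenames, batch_size=batch_size)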
