TensorFlow入门(十-I)tfrecord 固定维度数据读写

本例代码:https://github.com/yongyehuang/Tensorflow-Tutorial/tree/master/python/the_use_of_tfrecord

关于 tfrecord 的使用,分别介绍 tfrecord 进行三种不同类型数据的处理方法。
- 维度固定的 numpy 矩阵
- 可变长度的 序列 数据
- 图片数据

在 tf1.3 及以后版本中,推出了新的 Dataset API, 之前赶实验还没研究,可能以后都不太会用下面的方式写了。这些代码都是之前写好的,因为注释中都写得比较清楚了,所以直接上代码。

tfrecord_1_numpy_writer.py

# -*- coding:utf-8 -*- 

import tensorflow as tf
import numpy as np
from tqdm import tqdm

'''tfrecord 写入数据.
将固定shape的矩阵写入 tfrecord 文件。这种形式的数据写入 tfrecord 是最简单的。
refer: http://blog.csdn.net/qq_16949707/article/details/53483493
'''

# **1.创建文件,可以创建多个文件,在读取的时候只需要提供所有文件名列表就行了
writer1 = tf.python_io.TFRecordWriter('../data/test1.tfrecord')
writer2 = tf.python_io.TFRecordWriter('../data/test2.tfrecord')

"""
有一点需要注意的就是我们需要把矩阵转为数组形式才能写入
就是需要经过下面的 reshape 操作
在读取的时候再 reshape 回原始的 shape 就可以了
"""
X = np.arange(0, 100).reshape([50, -1]).astype(np.float32)
y = np.arange(50)

for i in tqdm(xrange(len(X))):  # **2.对于每个样本
    if i >= len(y) / 2:
        writer = writer2
    else:
        writer = writer1
    X_sample = X[i].tolist()
    y_sample = y[i]
    # **3.定义数据类型,按照这里固定的形式写,有float_list(好像只有32位), int64_list, bytes_list.
    example = tf.train.Example(
        features=tf.train.Features(
            feature={'X': tf.train.Feature(float_list=tf.train.FloatList(value=X_sample)),
                     'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y_sample]))}))
    # **4.序列化数据并写入文件中
    serialized = example.SerializeToString()
    writer.write(serialized)

print('Finished.')
writer1.close()
writer2.close()

tfrecord_1_numpy_reader.py

# -*- coding:utf-8 -*- 

import tensorflow as tf

'''read data
从 tfrecord 文件中读取数据,对应数据的格式为固定shape的数据。
'''

# **1.把所有的 tfrecord 文件名列表写入队列中
filename_queue = tf.train.string_input_producer(['../data/test1.tfrecord', '../data/test2.tfrecord'], num_epochs=None,
                                                shuffle=True)
# **2.创建一个读取器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# **3.根据你写入的格式对应说明读取的格式
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'X': tf.FixedLenFeature([2], tf.float32),  # 注意如果不是标量,需要说明数组长度
                                       'y': tf.FixedLenFeature([], tf.int64)}     # 而标量就不用说明
                                   )
X_out = features['X']
y_out = features['y']

print(X_out)
print(y_out)
# **4.通过 tf.train.shuffle_batch 或者 tf.train.batch 函数读取数据
"""
在shuffle_batch 函数中,有几个参数的作用如下:
capacity: 队列的容量,容量越大的话,shuffle 得就更加均匀,但是占用内存也会更多
num_threads: 读取进程数,进程越多,读取速度相对会快些,根据个人配置决定
min_after_dequeue: 保证队列中最少的数据量。
   假设我们设定了队列的容量C,在我们取走部分数据m以后,队列中只剩下了 (C-m) 个数据。然后队列会不断补充数据进来,
   如果后勤供应(CPU性能,线程数量)补充速度慢的话,那么下一次取数据的时候,可能才补充了一点点,如果补充完后的数据个数少于
   min_after_dequeue 的话,不能取走数据,得继续等它补充超过 min_after_dequeue 个样本以后才让取走数据。
   这样做保证了队列中混着足够多的数据,从而才能保证 shuffle 取值更加随机。
   但是,min_after_dequeue 不能设置太大,否则补充时间很长,读取速度会很慢。
"""
X_batch, y_batch = tf.train.shuffle_batch([X_out, y_out], batch_size=2,
                                          capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# **5.启动队列进行数据读取
# 下面的 coord 是个线程协调器,把启动队列的时候加上线程协调器。
# 这样,在数据读取完毕以后,调用协调器把线程全部都关了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in xrange(5):
    _X_batch, _y_batch = sess.run([X_batch, y_batch])
    print('** batch %d' % i)
    print('_X_batch:', _X_batch)
    print('_y_batch:', _y_batch)
    y_outputs.extend(_y_batch.tolist())
print(y_outputs)

# **6.最后记得把队列关掉
coord.request_stop()
coord.join(threads)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值