简介
本文介绍TensorFlow的第二种数据导入方法。
为了保持高效,这种方法稍显繁琐。分为如下几个步骤:
- 把所有样本写入二进制文件(只执行一次)
- 创建Tensor
,从二进制文件读取一个样本
- 创建Tensor
,从二进制文件随机读取一个mini-batch
- 把mini-batchTensor
传入网络作为输入节点。
二进制文件
使用tf.python_io.TFRecordWriter
创建一个专门存储tensorflow数据的writer
,扩展名为’.tfrecord’。
该文件中依次存储着序列化的tf.train.Example
类型的样本。
writer = tf.python_io.TFRecordWriter('/tmp/data.tfrecord')
for i in range(0, 10):
# 创建样本example
# ...
serialized = example.SerializeToString() # 序列化
writer.write(serialized) # 写入文件
writer.close()
每一个example
的feature
成员变量是一个dict
,存储一个样本的不同部分(例如图像像素+类标)。以下例子的样本中包含三个键a,b,c
:
# 创建样本example
a_data = 0.618 + i # float
b_data = [2016 + i, 2017+i] # int64
c_data = numpy.array([[0, 1, 2],[3, 4, 5]]) + i # bytes
c_data = c_data.astype(numpy.uint8)
c_raw = c.tostring() # 转化成字符串
example = tf.train.Example(
features=tf.train.Features(
feature={
'a': tf.train.Feature(
float_list=tf.train.FloatList(value&#