Tensor Flow官方网站上提供三种读取数据的方法
1. 预加载数据:在Tensor Flow图中定义常量或变量来保存所有数据,将数据直接嵌到数据图中,当训练数据较大时,很消耗内存。
如
x1=tf.constant([0,1])
x2=tf.constant([1,0])
y=tf.add(x1,x2)
2.填充数据:使用sess.run()的feed_dict参数,将Python产生的数据填充到后端,之前的MNIST数据集就是通过这种方法。也有消耗内存,数据类型转换耗时的缺点。
3. 从文件读取数据:从文件中直接读取,让队列管理器从文件中读取数据。分为两步
先把样本数据写入TFRecords二进制文件
再从队列中读取
TFRecord是TensorFlow提供的一种统一存储数据的二进制文件,能更好的利用内存,更方便的复制和移动,并且不需要单独的标记文件。下面通过代码来将MNIST转换成TFRecord的数据格式,其他数据集也类似。
#生成整数型的属性
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#生成字符串型的属性
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def convert_to(data_set,name):
'''
将数据填入到tf.train.Example的协议缓冲区(protocol buffer)中,将协议缓冲区序列
化为一个字符串,通过tf.python_io.TFRecordWriter写入TFRecords文件
'''
images=data_set.images
labels=data_set.labels
num_examples=data_set.num_examples
if images.shape[0]!=num_examples:
raise ValueError ('Imagessize %d does not match label size %d.'\
%(images.shape[0],num_examples))
rows=images.shape[1] #28
cols=images.shape[2] #28
depth=images.shape[3] #1 是黑白图像
filename = os.path.join(FLAGS.directory, name + '.tfrecords')
#使用下面语句就会将三个文件存储为一个TFRecord文件,当数据量较大时,最好将数据写入多个文件
#filename="C:/Users/dbsdz/Desktop/TF练习/TFRecord"
print('Writing',filename)
writer=tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image_raw=images[index].tostring() #将图像矩阵化为一个字符串
#写入协议缓冲区,height、width、depth、label编码成int 64类型,image——raw编码成二进制
example=tf.train.Example(features=tf.train.Features(feature={
'height':_int64_feature(rows),
'width':_int64_feature(cols),
'depth':_int64_feature(depth),
'label':_int64_feature(int(labels[index])),
'image_raw':_bytes_feature(image_raw)}))
writer.write(example.SerializeToString()) #序列化字符串
writer.close()
上面程序可以将MNIST数据集中所有的训练数据存储到三个TFRecord文件中。结果如下图
从队列中TFRecord文件,过程分三步
1. 创建张量,从二进制文件中读取一个样本
2. 创建张量,从二进制文件中随机读取一个mini-batch
3. 把每一批张量传入网络作为输入节点
具体代码如下
def read_and_decode(filename_queue): #输入文件名队列
reader=tf.TFRecordReader()
_,serialized_example=reader.read(filename_queue)
#解析一个example,如果需要解析多个样例,使用parse_example函数
features=tf.parse_single_example(
serialized_example,
#必须写明feature里面的key的名称
features={
#TensorFlow提供两种不同的属性解析方法,一种方法是tf.FixedLenFeature,
#这种方法解析的结果为一个Tensor。另一个方法是tf.VarLenFeature,
#这种方法得到的解析结果为Sp