python如何读取tfrecord_TFRecord格式存储数据与队列读取实例

Tensor Flow官方网站上提供三种读取数据的方法

1. 预加载数据:在Tensor Flow图中定义常量或变量来保存所有数据,将数据直接嵌到数据图中,当训练数据较大时,很消耗内存。

x1=tf.constant([0,1])

x2=tf.constant([1,0])

y=tf.add(x1,x2)

2.填充数据:使用sess.run()的feed_dict参数,将Python产生的数据填充到后端,之前的MNIST数据集就是通过这种方法。也有消耗内存,数据类型转换耗时的缺点。

3. 从文件读取数据:从文件中直接读取,让队列管理器从文件中读取数据。分为两步

先把样本数据写入TFRecords二进制文件

再从队列中读取

TFRecord是TensorFlow提供的一种统一存储数据的二进制文件,能更好的利用内存,更方便的复制和移动,并且不需要单独的标记文件。下面通过代码来将MNIST转换成TFRecord的数据格式,其他数据集也类似。

#生成整数型的属性

def _int64_feature(value):

return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#生成字符串型的属性

def _bytes_feature(value):

return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to(data_set,name):

'''

将数据填入到tf.train.Example的协议缓冲区(protocol buffer)中,将协议缓冲区序列

化为一个字符串,通过tf.python_io.TFRecordWriter写入TFRecords文件

'''

images=data_set.images

labels=data_set.labels

num_examples=data_set.num_examples

if images.shape[0]!=num_examples:

raise ValueError ('Imagessize %d does not match label size %d.'\

%(images.shape[0],num_examples))

rows=images.shape[1] #28

cols=images.shape[2] #28

depth=images.shape[3] #1 是黑白图像

filename = os.path.join(FLAGS.directory, name + '.tfrecords')

#使用下面语句就会将三个文件存储为一个TFRecord文件,当数据量较大时,最好将数据写入多个文件

#filename="C:/Users/dbsdz/Desktop/TF练习/TFRecord"

print('Writing',filename)

writer=tf.python_io.TFRecordWriter(filename)

for index in range(num_examples):

image_raw=images[index].tostring() #将图像矩阵化为一个字符串

#写入协议缓冲区,height、width、depth、label编码成int 64类型,image——raw编码成二进制

example=tf.train.Example(features=tf.train.Features(feature={

'height':_int64_feature(rows),

'width':_int64_feature(cols),

'depth':_int64_feature(depth),

'label':_int64_feature(int(labels[index])),

'image_raw':_bytes_feature(image_raw)}))

writer.write(example.SerializeToString()) #序列化字符串

writer.close()

上面程序可以将MNIST数据集中所有的训练数据存储到三个TFRecord文件中。结果如下图

从队列中TFRecord文件,过程分三步

1. 创建张量,从二进制文件中读取一个样本

2. 创建张量,从二进制文件中随机读取一个mini-batch

3. 把每一批张量传入网络作为输入节点

具体代码如下

def read_and_decode(filename_queue): #输入文件名队列

reader=tf.TFRecordReader()

_,serialized_example=reader.read(filename_queue)

#解析一个example,如果需要解析多个样例,使用parse_example函数

features=tf.parse_single_example(

serialized_example,

#必须写明feature里面的key的名称

features={

#TensorFlow提供两种不同的属性解析方法,一种方法是tf.FixedLenFeature,

#这种方法解析的结果为一个Tensor。另一个方法是tf.VarLenFeature,

#这种方法得到的解析结果为Sp

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值