tensorflow读取数据-tfrecord格式

概述:关于tensorflow读取数据,官网给出了三种方法:1、供给数据:在tensorflow程序运行的每一步,让python代码来供给数据2、从文件读取数据:建立输入管线从文件中读取数据3、预加载数据:如果数据量不太大,可以在程序中定义常量或者变量来保存所有的数据。这里主要介绍一种比较通用、高效的数据读取方法,就是tensorflow官方推荐的标准格式:tfrecord。tfrecor
摘要由CSDN通过智能技术生成

概述:


关于tensorflow读取数据,官网给出了三种方法:
1、供给数据:在tensorflow程序运行的每一步,让python代码来供给数据
2、从文件读取数据:建立输入管线从文件中读取数据
3、预加载数据:如果数据量不太大,可以在程序中定义常量或者变量来保存所有的数据。

这里主要介绍一种比较通用、高效的数据读取方法,就是tensorflow官方推荐的标准格式:tfrecord。

tfrecord数据文件


tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。
tfrecord文件包含了tf.train.Example 协议缓冲区(protocol buffer,协议缓冲区包含了特征 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议缓冲区(protocol buffer),将协议缓冲区序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。
tf.train.Example的定义如下:

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

从上述代码可以看出,tf.train.Example中包含了属性名称到取值的字典,其中属性名称为字符串,属性的取值可以为字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)。

代码实现


将数据保存为tfrecord格式

具体来说,首先需要给定tfrecord文件名称,并创建一个文件:

tfrecords_filename = './tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename) # 创建.tfrecord文件,准备写入

之后就可以创建一个循环来依次写入数据:

    for i in range(100):
        img_raw = np.random.random_integers(0,255,size=(7,30)) # 创建7*30,取值在0-255之间随机数组
        img_raw = img_raw.tostring()
        example = tf.train.Example(features=tf.train.Features(
                feature={
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),     
                'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
                }))
        writer.write(example.SerializeToString()) 

    writer.close()

example = tf.train.Example()这句将数据赋给了变量example(可以看到里面是通过字典结构实现的赋值),然后用writer.write(example.SerializeToString()) 这句实现写入。

值得注意的是赋值给example的数据格式。从前面tf.train.Example的定义可知,tfrecord支持整型、浮点数和二进制三种格式,分别是

tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
tf.train.Feature(bytes_list = tf.train.FloatList(value=[float_scalar]))

例如图片等数组形式(array)的数据,可以保存为numpy array的格式,转换为string,然后保存到二进制格式的feature中。对于单个的数值(scalar),可以直接赋值。这里value=[×]的[]非常重要,也就是说输入的必须是列表(list)。当然,对于输入数据是向量形式的,可以根据数据类型(float还是int)分别保存。并且在保存的时候还可以指定数据的维数。

读取tfrecord数据

从TFRecords文件中读取数据, 首先需要用tf.train.string_input_producer生成一个解析队列。之后调用tf.TFRecordReader的tf.parse_single_example解析器。如下图:

  • 43
    点赞
  • 140
    收藏
    觉得还不错? 一键收藏
  • 28
    评论
评论 28
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值