TFRecord生成与读取

1. TFRecord的作用

为了高效的读取数据,可以将数据进行序列化存储,这样也便于网络流式读取数据,TFRecord就是一种保存记录的方法可以允许你讲任意的数据转换为TensorFlow所支持的格式,这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。TFRecord是谷歌推荐的## 标题一种常用的存储二进制序列数据的文件格式。
  TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

2. TFRecord生成方式

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocolbuffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriterclass写入到TFRecords文件。

将一张图片转化成TFRecord,代码如下:

# _*_coding:utf-8_*_

import tensorflow as tf


def write_test(input, output):
    """

    :param input:
    :param output:
    """
    # 借助于TFRecordWriter 才能将信息写入TFRecord 文件
    writer = tf.python_io.TFRecordWriter(output)

    # 读取图片并进行解码
    image = tf.read_file(input)
    image = tf.image.decode_jpeg(image)

    with tf.Session() as sess:
        image = sess.run(image)
        shape = image.shape
        # 将图片转换成string
        image_data = image.tostring()

        name = bytes('cat', encoding='utf-8')

        # 创建Example对象,并将Feature一一对应填充进去
        example = tf.train.Example(features=tf.train.Features(feature={
            'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
            'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
            'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
        }
        ))
        # 将example序列化成string 类型,然后写入。
        writer.write(example.SerializeToString())
    writer.close()


if __name__ == '__main__':
    input_photo = '/home/zhy/0000.png '
    output_file = 'cat.tfrecord'
    write_test(input_photo, output_file)

下面解释一下代码:
1, 将图片解码,然后转化成string数据,然后填充进去。
2, Feature 的value 是列表,所以记得加上 []
3, example需要调用 SerializetoString() 进行序列化后才行

3. TFRecord的读取

从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个parse_single_example操作可以将Example协议内存块(protocolbuffer)解析为张量。

TFRecord 文件读取为图片,代码如下:

# _*_coding:utf-8_*_
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


def _parse_record(example_photo):
    features = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string)
    }
    parsed_features = tf.parse_single_example(example_photo, features=features)
    return parsed_features


def read_test(input_file):
    # 用dataset读取TFRecords文件
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    dataset = dataset.batch(3, drop_remainder=True)
    iterator = dataset.make_one_shot_iterator()  # 从dataset中实例化了一个Iterator,这个Iterator是一个“one shot iterator”,即只能从头到尾读取一次

    with tf.Session() as sess:
        features = sess.run(iterator.get_next())
        name = features['name']
        name = name.decode()
        img_data = features['data']
        shape = features['shape']

        # 从bytes数组中加载图片原始数据,并重新reshape,它的结果是 ndarray 数组
        img_data = np.fromstring(img_data, dtype=np.uint8)
        image_data = np.reshape(img_data, shape)

        plt.figure()
        # 显示图片
        plt.imshow(image_data)
        plt.show()

        # 将数据重新编码成jpg图片并保存
        img = tf.image.encode_jpeg(image_data)
        tf.gfile.GFile('cat_encode.png', 'wb').write(img.eval())


if __name__ == '__main__':
    read_test("/home/zhy/tfrecords/test_0_2.tfrecords")

下面解释一下代码:

1,首先使用dataset去读取tfrecord文件

2,在解析example 的时候,用现成的API:tf.parse_single_example

3,用 np.fromstring() 方法就可以获取解析后的string数据,记得把数据还原成 np.uint8

4,用 tf.image.encode_jepg() 方法可以将图片数据编码成 jpeg 格式

5,用 tf.gfile.GFile 对象可以把图片数据保存到本地

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值