高效读取数据的方法 (TFRecord)

参考了这篇博客的内容,做了些增加修改
TFRecord 是Google官方推荐的一种数据格式,是Google专门为TensorFlow设计的一种数据格式。

实际上,TFRecord是一种二进制文件,其能更好的利用内存,其内部包含了多个tf.train.Example, 而Example是protocol buffer(protobuf) 数据标准 [3] [4] 的实现,在一个Example消息体中包含了一系列的tf.train.feature属性,而 每一个feature 是一个key-value的键值对,其中,key 是string类型,而value 的取值有三种:

bytes_list: 可以存储string 和byte两种数据类型。
float_list: 可以存储float(float32)与double(float64) 两种数据类型 。
int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64 。

tf.Example 类就是一种将数据表示为{‘string’: value}形式的 message类型,TensorFlow经常使用 tf.Example 来写入,读取 TFRecord数据。

通常情况下,tf.Example中可以使用以下几种格式:
tf.train.BytesList: 可以使用的类型包括 string和byte
tf.train.FloatList: 可以使用的类型包括 float和double
tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64

TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature

tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
 	
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape))) 
 
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
将数据转化为TFRecord文件

Google官方推荐在对于中大数据集来说,先将数据集转化为TFRecord数据(.tfrecords), 这样可加快你在数据读取, 预处理中的速度。

将数据转化为 tfrecord 格式只需要三步, 下面以三个features:context,question, answer为例 :

writer = tf.python_io.TFRecordWriter(out_file_name)  # 1. 定义 writer对象

for data in dataes:
    context = dataes[0]
    question = dataes[1]
    answer = dataes[2]

    """ 2. 定义features """
    example = tf.train.Example(
        features = tf.train.Features(
            feature = {
               'context': tf.train.Feature(
                 int64_list=tf.train.Int64List(value=context)),
               'question': tf.train.Feature(
                 int64_list=tf.train.Int64List(value=question)),
               'answer': tf.train.Feature(
                 int64_list=tf.train.Int64List(value=answer))
            }))

    """ 3. 序列化,写入"""
    serialized = example.SerializeToString()
    writer.write(serialized)
将一张图片转化成TFRecord 文件
import tensorflow as tf
 
def write_test(input, output):
    # 借助于TFRecordWriter 才能将信息写入TFRecord 文件
    writer = tf.python_io.TFRecordWriter(output)
 
    # 读取图片并进行解码
    image = tf.read_file(input)
    image = tf.image.decode_jpeg(image)
 
    with tf.Session() as sess:
        image = sess.run(image)
        shape = image.shape
        # 将图片转换成string
        image_data = image.tostring()
        print(type(image))
        print(len(image_data))
        name = bytes('cat', encoding='utf-8')
        print(type(name))
        # 创建Example对象,并将Feature一一对应填充进去
        example = tf.train.Example(features=tf.train.Features(feature={
             'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
             'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
             'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
        }
        ))
        # 将example序列化成string 类型,然后写入。
        writer.write(example.SerializeToString())
    writer.close()
 
 
if __name__ == '__main__':
    input_photo = 'cat.jpg'
    output_file = 'cat.tfrecord'
    write_test(input_photo, output_file)
TFRecord 文件读取为图片
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
 
def _parse_record(example_photo):
    features = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string)
    }
    parsed_features = tf.parse_single_example(example_photo,features=features)
    return parsed_features
 
def read_test(input_file):
    # 用dataset读取TFRecords文件
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    iterator = dataset.make_one_shot_iterator()
 
    with tf.Session() as sess:
        features = sess.run(iterator.get_next())
        name = features['name']
        name = name.decode()
        img_data = features['data']
        shape = features['shape']
        print("==============")
        print(type(shape))
        print(len(img_data))
 
        # 从bytes数组中加载图片原始数据,并重新reshape,它的结果是 ndarray 数组
        img_data = np.fromstring(img_data, dtype=np.uint8)
        image_data = np.reshape(img_data, shape)
 
        plt.figure()
        # 显示图片
        plt.imshow(image_data)
        plt.show()
 
        # 将数据重新编码成jpg图片并保存
        img = tf.image.encode_jpeg(image_data)
        tf.gfile.GFile('cat_encode.jpg', 'wb').write(img.eval())
 
if __name__ == '__main__':
    read_test("cat.tfrecord")

1,首先使用dataset去读取tfrecord文件

2,在解析example 的时候,用现成的API:tf.parse_single_example

3,用 np.fromstring() 方法就可以获取解析后的string数据,记得把数据还原成 np.uint8

4,用 tf.image.encode_jepg() 方法可以将图片数据编码成 jpeg 格式

5,用 tf.gfile.GFile 对象可以把图片数据保存到本地

6,因为将图片 shape 写入了example 中,所以解析的时候必须指定维度,在这里 [3],不然程序会报错。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值