TFRecord的写入与读取

目录

0、TFRecords是啥

1、数据存入TFRecords流程

2、读取TFRecords流程


0、TFRecords是啥

TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,它能更好的利用内存,更方便复制和移动。

TFRecords是为了将二进制数据和标签(训练的目标类别标签)数据存储在同一个文件中。

文件的保存格式:*.tfrecords

写入文件的内容:每个样本的Example协议块(键值对,类字典格式,表示一个样本数据)
 

1、数据存入TFRecords流程

(1) 构造存储器

(2) 构造每一个样本的Example

(3) 写入序列化的Example

write(record):向文件中写入一个字符串记录(单个样本)

close():关闭文件写入器

注:Example.SerializeToString() 将单个样本的Example协议序列化为字符串,然后write()写入

(1) 建立TFRecord存储器

  • tf.python_io.TFRecordWriter(path)
    • 写入tfrecords文件
    • path:tfrecords文件的保存路径及文件名
    • return:文件写入器

(2) 构造每个样本的Example协议块

  • tf.train.Example(features=None) 样本的Example协议块
    • Features: tf.train.Features类型的特征实例
    • return: example 协议块
  • tf.train.Features(features=None) 每个样本信息的键值对
    • feature: 字典类型,key为要保存的名字,value为tf.train.Feature类型实例
    • reutrn:Features类型
  • tf.train.Feature(**options) 单个特征值的内容
  • **options: 例如:
    • bytes_list=tf.train.BytesList(value=[Bytes])
    • int64_list=tf.train.Int64List(value=[Value])
    • float_list=tf.train.FloatList(value=[Value])
  • 可以定义如下代码来进行简化
# 生成字符串型的属性。
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# 生成整数型的属性。
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# 生成实数型的属性。
def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

demo.py:

import tensorflow as tf
import os 
 
# 读取二进制文件(包含多个样本),并对每个样本构造example协议块并写入tfrecords文件。
# 二进制表示的图片数据。图片尺寸32*32,3个通道。每个图片样本的大小固定(1+32*32*3=3073字节)第一个字节表示图片类别,32*32*3个字节表示图片像素值
# 数据来源:http://www.cs.toronto.edu/~kriz/cifar.html
 
# 找到数据文件,放入列表   路径+名字->列表当中
file_names = os.listdir("./data/")
print(file_names)  # ['data_batch_1.bin', 'data_batch_2.bin', 'data_batch_3.bin']
# 拼接路径和文件名
filename_list = [os.path.join("./data/", file) for file in file_names]
 
# 1、构造文件名队列
file_queue = tf.train.string_input_producer(filename_list)
 
# 2、构造二进制文件读取器,读取内容。 每个样本的字节数
reader = tf.FixedLengthRecordReader(3073)  # 1+32*32*3=3073 (图片)
key, value = reader.read(file_queue)  # value表示一个样本的数据
print(value)  # Tensor("ReaderReadV2:1", shape=(), dtype=string)
 
# 3、解码内容,二进制文件内容的解码 (将读取的字符串类型张量转换成uint8类型的张量)
decoded_data = tf.decode_raw(value, tf.uint8)
print(decoded_data)  # Tensor("DecodeRaw:0", shape=(?,), dtype=uint8)
 
# 4、分割出标签和图片数据,划分特征值和目标值 (decoded_data的第一个字节表示图片类别(标签))
label = tf.cast(tf.slice(decoded_data, [0], [1]), tf.int32)
image = tf.slice(decoded_data, [1], [3072])  # 从下标1开始切,取3072列。
print(image)  # Tensor("Slice_1:0", shape=(3072,), dtype=uint8)
 
# 5、可以对图片的特征数据进行形状的改变 [3072] --> [32, 32, 3]
image_reshape = tf.reshape(image, [32, 32, 3])
print(label)  # Tensor("Cast:0", shape=(1,), dtype=int32)
print(image_reshape)  # Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8)
 
# 6、批处理数据
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
print(image_batch)  # Tensor("batch:0", shape=(10, 32, 32, 3), dtype=uint8)
print(label_batch)  # Tensor("batch:1", shape=(10, 1), dtype=int32)
 
 
# 开启会话运行结果
with tf.Session() as sess:
    # 创建一个线程协调器
    coord = tf.train.Coordinator()
 
    # 开启读文件的子线程
    threads = tf.train.start_queue_runners(sess, coord=coord)
 
    # 1、建立TFRecord存储器
    writer = tf.python_io.TFRecordWriter("./mydata/dog.tfrecords")
 
    # 2、循环将每个样本构造example协议块并写入tfrecords文件
    for i in range(10):  # 一个批次中有10个样本
        # 取出第i个图片数据的特征值和目标值
        image = image_batch[i].eval().tostring()  # eval()相当于sess.run()。 tostring()将ndarray转换成字符串类型。
        label = int(label_batch[i].eval()[0])  # 取出目标值 (label_batch是二维的张量)
 
        # 构造一个样本的example协议块
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
        }))
        #example = tf.train.Example(features=tf.train.Features(feature={
        #    "image": _bytes_feature(image),
        #    "label": _int64_feature(label)
        #}))
 
        # 写入单个样本
        writer.write(example.SerializeToString())  # SerializeToString()将样本的协议块序列化成字符串
 
    # 关闭
    writer.close()
 
    # 结束子线程
    coord.request_stop()
    # 等待子线程结束
    coord.join(threads)

 

2、读取TFRecords流程

(1)构造TFRecords解析器

(2)解析Example

(3)转换格式,bytes解码

解析TFRecords的Example协议块

  • tf.parse_single_example(serialized, features=None,name=None)
    • serialized=读取出的序列化的Example协议块内容
    • features:dirct字典类型,键为读取出的特征名,值为FixedLenFeature类型
  • tf.FixedLenFeature(shape,dtype) 表示单个特征内容
    • shape:输入数据的形状,一般不指定,为空列表
    • dtype:输入数据的类型,与存储进文件的类型要一致!!!(只能是float32,int64,string)

demo.py:

import tensorflow as tf
import os
 
 
# 读取tfrecords文件。
 
# 找到数据文件,放入列表   路径+名字->列表当中
file_names = os.listdir("./mydata/")
print(file_names)  # ['dog.tfrecords']
# 拼接路径和文件名
filename_list = [os.path.join("./mydata/", file) for file in file_names]
 
# 1、构造文件名队列
file_queue = tf.train.string_input_producer(filename_list)
 
# 2、构造文件阅读器,读取example协议块
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)  # value是序列化后的example协议块(一个样本对应一个协议块)
 
# 3、解析example协议块。 解析成字典类型(键值对形式)的样本信息
features = tf.parse_single_example(value, features={
    "image": tf.FixedLenFeature([], tf.string),  # 要与存储的key和数据类型保持对应。
    "label": tf.FixedLenFeature([], tf.int64)
})
 
# 4、解码内容,解码成数值类型。 如果读取的内容格式是string类型,就需要解码, 如果是int64,float32不需要解码
image = tf.decode_raw(features["image"], tf.uint8)  # string类型解码成uint8类型。
 
# 固定图片(样本)的形状 (批处理需要数据形状固定)
image_reshape = tf.reshape(image, [32, 32, 3])  # 3表示图片3个通道
 
label = tf.cast(features["label"], tf.int32)  # 转换类型
print(image_reshape)  # Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8)
print(label)  # Tensor("Cast:0", shape=(), dtype=int32)
 
# 进行批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
 
 
# 开启会话运行结果
with tf.Session() as sess:
    # 创建一个线程协调器
    coord = tf.train.Coordinator()
 
    # 开启读文件的子线程
    threads = tf.train.start_queue_runners(sess, coord=coord)
 
    # 打印读取的内容
    print(sess.run(label_batch))
    '''
    [5 6 0 9 4 3 1 2 9 7]
    '''
    print(sess.run(image_batch))
    '''
    [[[[178 178 178]
       [178 179 179]
       [179 180 180]
       ...
       [176 175 173]
       [171 168 166]
       [163 159 155]]]]
    '''
 
    # 结束子线程
    coord.request_stop()
    # 等待子线程结束
    coord.join(threads)
 

参考

https://www.cnblogs.com/hellcat/p/8146748.html

TFRecord 统一输入数据格式和组合数据: https://www.jianshu.com/p/b5687b88a3ea

写入: https://blog.csdn.net/houyanhua1/article/details/88231017

读取: https://blog.csdn.net/houyanhua1/article/details/88236001

https://www.tensorflow.org/guide/datasets?hl=zh-cn

https://my.oschina.net/u/3800567/blog/1637798

https://blog.csdn.net/yeqiustu/article/details/79793454

https://www.2cto.com/kf/201702/604326.html

https://www.cnblogs.com/whu-zeng/p/6293589.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值