目录
0、TFRecords是啥
TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,它能更好的利用内存,更方便复制和移动。
TFRecords是为了将二进制数据和标签(训练的目标类别标签)数据存储在同一个文件中。
文件的保存格式:*.tfrecords
写入文件的内容:每个样本的Example协议块(键值对,类字典格式,表示一个样本数据)
1、数据存入TFRecords流程
(1) 构造存储器
(2) 构造每一个样本的Example
(3) 写入序列化的Example
write(record):向文件中写入一个字符串记录(单个样本)
close():关闭文件写入器
注:Example.SerializeToString() 将单个样本的Example协议序列化为字符串,然后write()写入
(1) 建立TFRecord存储器
- tf.python_io.TFRecordWriter(path)
- 写入tfrecords文件
- path:tfrecords文件的保存路径及文件名
- return:文件写入器
(2) 构造每个样本的Example协议块
- tf.train.Example(features=None) 样本的Example协议块
- Features: tf.train.Features类型的特征实例
- return: example 协议块
- tf.train.Features(features=None) 每个样本信息的键值对
- feature: 字典类型,key为要保存的名字,value为tf.train.Feature类型实例
- reutrn:Features类型
- tf.train.Feature(**options) 单个特征值的内容
- **options: 例如:
- bytes_list=tf.train.BytesList(value=[Bytes])
- int64_list=tf.train.Int64List(value=[Value])
- float_list=tf.train.FloatList(value=[Value])
- 可以定义如下代码来进行简化
# 生成字符串型的属性。
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 生成整数型的属性。
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 生成实数型的属性。
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
demo.py:
import tensorflow as tf
import os
# 读取二进制文件(包含多个样本),并对每个样本构造example协议块并写入tfrecords文件。
# 二进制表示的图片数据。图片尺寸32*32,3个通道。每个图片样本的大小固定(1+32*32*3=3073字节)第一个字节表示图片类别,32*32*3个字节表示图片像素值
# 数据来源:http://www.cs.toronto.edu/~kriz/cifar.html
# 找到数据文件,放入列表 路径+名字->列表当中
file_names = os.listdir("./data/")
print(file_names) # ['data_batch_1.bin', 'data_batch_2.bin', 'data_batch_3.bin']
# 拼接路径和文件名
filename_list = [os.path.join("./data/", file) for file in file_names]
# 1、构造文件名队列
file_queue = tf.train.string_input_producer(filename_list)
# 2、构造二进制文件读取器,读取内容。 每个样本的字节数
reader = tf.FixedLengthRecordReader(3073) # 1+32*32*3=3073 (图片)
key, value = reader.read(file_queue) # value表示一个样本的数据
print(value) # Tensor("ReaderReadV2:1", shape=(), dtype=string)
# 3、解码内容,二进制文件内容的解码 (将读取的字符串类型张量转换成uint8类型的张量)
decoded_data = tf.decode_raw(value, tf.uint8)
print(decoded_data) # Tensor("DecodeRaw:0", shape=(?,), dtype=uint8)
# 4、分割出标签和图片数据,划分特征值和目标值 (decoded_data的第一个字节表示图片类别(标签))
label = tf.cast(tf.slice(decoded_data, [0], [1]), tf.int32)
image = tf.slice(decoded_data, [1], [3072]) # 从下标1开始切,取3072列。
print(image) # Tensor("Slice_1:0", shape=(3072,), dtype=uint8)
# 5、可以对图片的特征数据进行形状的改变 [3072] --> [32, 32, 3]
image_reshape = tf.reshape(image, [32, 32, 3])
print(label) # Tensor("Cast:0", shape=(1,), dtype=int32)
print(image_reshape) # Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8)
# 6、批处理数据
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
print(image_batch) # Tensor("batch:0", shape=(10, 32, 32, 3), dtype=uint8)
print(label_batch) # Tensor("batch:1", shape=(10, 1), dtype=int32)
# 开启会话运行结果
with tf.Session() as sess:
# 创建一个线程协调器
coord = tf.train.Coordinator()
# 开启读文件的子线程
threads = tf.train.start_queue_runners(sess, coord=coord)
# 1、建立TFRecord存储器
writer = tf.python_io.TFRecordWriter("./mydata/dog.tfrecords")
# 2、循环将每个样本构造example协议块并写入tfrecords文件
for i in range(10): # 一个批次中有10个样本
# 取出第i个图片数据的特征值和目标值
image = image_batch[i].eval().tostring() # eval()相当于sess.run()。 tostring()将ndarray转换成字符串类型。
label = int(label_batch[i].eval()[0]) # 取出目标值 (label_batch是二维的张量)
# 构造一个样本的example协议块
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
#example = tf.train.Example(features=tf.train.Features(feature={
# "image": _bytes_feature(image),
# "label": _int64_feature(label)
#}))
# 写入单个样本
writer.write(example.SerializeToString()) # SerializeToString()将样本的协议块序列化成字符串
# 关闭
writer.close()
# 结束子线程
coord.request_stop()
# 等待子线程结束
coord.join(threads)
2、读取TFRecords流程
(1)构造TFRecords解析器
(2)解析Example
(3)转换格式,bytes解码
解析TFRecords的Example协议块
- tf.parse_single_example(serialized, features=None,name=None)
- serialized=读取出的序列化的Example协议块内容
- features:dirct字典类型,键为读取出的特征名,值为FixedLenFeature类型
- tf.FixedLenFeature(shape,dtype) 表示单个特征内容
- shape:输入数据的形状,一般不指定,为空列表
- dtype:输入数据的类型,与存储进文件的类型要一致!!!(只能是float32,int64,string)
demo.py:
import tensorflow as tf
import os
# 读取tfrecords文件。
# 找到数据文件,放入列表 路径+名字->列表当中
file_names = os.listdir("./mydata/")
print(file_names) # ['dog.tfrecords']
# 拼接路径和文件名
filename_list = [os.path.join("./mydata/", file) for file in file_names]
# 1、构造文件名队列
file_queue = tf.train.string_input_producer(filename_list)
# 2、构造文件阅读器,读取example协议块
reader = tf.TFRecordReader()
key, value = reader.read(file_queue) # value是序列化后的example协议块(一个样本对应一个协议块)
# 3、解析example协议块。 解析成字典类型(键值对形式)的样本信息
features = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string), # 要与存储的key和数据类型保持对应。
"label": tf.FixedLenFeature([], tf.int64)
})
# 4、解码内容,解码成数值类型。 如果读取的内容格式是string类型,就需要解码, 如果是int64,float32不需要解码
image = tf.decode_raw(features["image"], tf.uint8) # string类型解码成uint8类型。
# 固定图片(样本)的形状 (批处理需要数据形状固定)
image_reshape = tf.reshape(image, [32, 32, 3]) # 3表示图片3个通道
label = tf.cast(features["label"], tf.int32) # 转换类型
print(image_reshape) # Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8)
print(label) # Tensor("Cast:0", shape=(), dtype=int32)
# 进行批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
# 开启会话运行结果
with tf.Session() as sess:
# 创建一个线程协调器
coord = tf.train.Coordinator()
# 开启读文件的子线程
threads = tf.train.start_queue_runners(sess, coord=coord)
# 打印读取的内容
print(sess.run(label_batch))
'''
[5 6 0 9 4 3 1 2 9 7]
'''
print(sess.run(image_batch))
'''
[[[[178 178 178]
[178 179 179]
[179 180 180]
...
[176 175 173]
[171 168 166]
[163 159 155]]]]
'''
# 结束子线程
coord.request_stop()
# 等待子线程结束
coord.join(threads)
参考
https://www.cnblogs.com/hellcat/p/8146748.html
TFRecord 统一输入数据格式和组合数据: https://www.jianshu.com/p/b5687b88a3ea
写入: https://blog.csdn.net/houyanhua1/article/details/88231017
读取: https://blog.csdn.net/houyanhua1/article/details/88236001
https://www.tensorflow.org/guide/datasets?hl=zh-cn
https://my.oschina.net/u/3800567/blog/1637798
https://blog.csdn.net/yeqiustu/article/details/79793454