前言
关于tfrecord在这里做个记录
参考:
https://blog.csdn.net/weixin_42052460/article/details/80714539
https://blog.csdn.net/xiezongsheng1990/article/details/82713014
提示:以下是本篇文章正文内容,下面案例可供参考
一、tfrecord是什么?
tfrecord是一种存储文件的格式,有利于TensorFlow加快其内容的读取速度和内存管理。
二、使用步骤
1.引入库
代码如下(示例):
import tensorflow as tf
import os
import cv2
import matplotlib.pyplot as plt
2.tfrecord之写入
代码如下(示例):
def tfrecordwriter(input_path, record_path):
writer = tf.python_io.TFRecordWriter(record_path)
imagename = [f for f in os.listdir(input_path) if f.endswith('.jpg')]
print(imagename)
for image in imagename:
img = cv2.imread(input_path + image)
img.resize(1000,1000,3)
rawimg = img.tobytes()
features = tf.train.Features(feature = {"img_name":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image.encode()])),
"img_type": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'jpg'])),
"img": tf.train.Feature(
bytes_list=tf.train.BytesList(value=[rawimg]))})
# 'width':tf.train.Feature(int64_list = tf.train.Int64List(value = [img[0]])),
# 'height':tf.train.Feature(int64_list = tf.train.Int64List(value = [img[1]]))})
example = tf.train.Example(features=features)
#example.SerializeToString()是将example压缩成二进制文件
writer.write(example.SerializeToString())
writer.close()
example = next(tf.python_io.tf_record_iterator(record_path))
print(tf.train.Example.FromString(example)) #打印example的内容
tf.python_io.TFRecordWriter(path)#创建一个tfrecord的句柄
tf.train.Features#创建具有多个feature的数据特征
tf.train.Example#创建example协议块,该协议块用上述的feature表示
example.SerializeToString()#将example压缩成二进制文件
writer.write(example.SerializeToString())#写入开始创建的句柄
writer.close()#关闭文件
写入数据流
原本内容转化为bytes形式——>数据块合并成tf.train.Features(类似dict形式)
——>放入example压缩序列化——>写入句柄
3.tfrecord之读出
def read_tfcord(tfrecord):
filename_queue = tf.train.string_input_producer([tfrecord]) #创建读入文件队列(记得是队列[])
reader = tf.TFRecordReader()#创建读入句柄
_, serialized_example = reader.read(filename_queue)#读取文件队列,是压缩序列化的example
feature = tf.parse_single_example(serialized_example, features={"img_name": tf.FixedLenFeature([], tf.string),
"img":tf.FixedLenFeature([], tf.string)})
#解析压缩的example,通过字典形式获取内容(类型为Tensor),获取具体指需要建立session
imagename = feature["img_name"]
img = tf.decode_raw(feature["img"],tf.uint8)#解编码,因为之前把数据转换成二进制流
img = tf.reshape(img,[1000,1000,3])
# plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
# plt.show()
# img = tf.decode_raw(feature['rawimg'], tf.uint8)
# shape = [feature['width'], feature['height'], 3]
# img = tf.reshape(img, shape)
print(imagename)
return imagename, img
读出数据流
创建队列和读句柄——>读取序列化的example
——>解析example——>解编码
4.整体实现
if __name__ == '__main__':
record_path = 'train.tfrecord'
input_path = 'image/'
with tf.Session() as sess:
tfrecordwriter(input_path, record_path)
img_name, img = read_tfcord(record_path)
# 开启一个协调器
coord = tf.train.Coordinator()
#队列填充
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
for i in range(3):
name, image = sess.run([img_name, img])
print(name)
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# plt.imshow(image)
plt.show()
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
#协调器发出所有线程的终止信号
coord.request_stop()
# 把开启的线程加入主线程,等待thread结束
coord.join(threads)
tensor内容通过建立session才能输出。
其中:
Coordinator类用来管理在Session中的多个线程,可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,该线程捕获到这个异常之后就会终止所有线程。使用 tf.train.Coordinator来创建一个线程管理器(协调器)对象。
QueueRunner类用来启动tensor的入队线程,可以用来启动多个工作线程同时将多个tensor(训练数据)推送入文件名称队列中,具体执行函数是 tf.train.start_queue_runners , 只有调用 tf.train.start_queue_runners之后,才会真正把tensor推入内存序列中,供计算单元调用,否则会由于内存序列为空,数据流图会处于一直等待状态。
总结
一起加油!