demo.py(TFRecord,读取TFRecord,TFRecordReader()):
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # 设置警告级别
# 读取tfrecords文件。
# 找到数据文件,放入列表 路径+名字->列表当中
file_names = os.listdir("./mydata/")
print(file_names) # ['dog.tfrecords']
# 拼接路径和文件名
filename_list = [os.path.join("./mydata/", file) for file in file_names]
# 1、构造文件名队列
file_queue = tf.train.string_input_producer(filename_list)
# 2、构造文件阅读器,读取example协议块
reader = tf.TFRecordReader()
key, value = reader.read(file_queue) # value是序列化后的example协议块(一个样本对应一个协议块)
# 3、解析example协议块。 解析成字典类型(键值对形式)的样本信息
features = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string), # 要与存储的key和数据类型保持对应。
"label": tf.FixedLenFeature([], tf.int64)
})
# 4、解码内容,解码成数值类型。 如果读取的内容格式是string类型,就需要解码, 如果是int64,float32不需要解码
image = tf.decode_raw(features["image"], tf.uint8) # string类型解码成uint8类型。
# 固定图片(样本)的形状 (批处理需要数据形状固定)
image_reshape = tf.reshape(image, [32, 32, 3]) # 3表示图片3个通道
label = tf.cast(features["label"], tf.int32) # 转换类型
print(image_reshape) # Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8)
print(label) # Tensor("Cast:0", shape=(), dtype=int32)
# 进行批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
# 开启会话运行结果
with tf.Session() as sess:
# 创建一个线程协调器
coord = tf.train.Coordinator()
# 开启读文件的子线程
threads = tf.train.start_queue_runners(sess, coord=coord)
# 打印读取的内容
print(sess.run(label_batch))
'''
[5 6 0 9 4 3 1 2 9 7]
'''
print(sess.run(image_batch))
'''
[[[[178 178 178]
[178 179 179]
[179 180 180]
...
[176 175 173]
[171 168 166]
[163 159 155]]]]
'''
# 结束子线程
coord.request_stop()
# 等待子线程结束
coord.join(threads)