参考 Slim读取TFrecord文件 - 云+社区 - 腾讯云
目录
1、TFrecord文件的格式定义
def int64_feature(values):
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def float_feature(values):
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(float_list=tf.train.FloatList(value=values))
def image_to_tfexample(image_data, image_format, height, width, class_id):
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': bytes_feature(image_data),
'image/format': bytes_feature(image_format),
'image/class/label': int64_feature(class_id),
'image/height': int64_feature(height),
'image/width': int64_feature(width),
}))
这里要注意的是,TFrecord文件的格式定义中,一定要包含“image/encoded”和“image/format”两个关键字 ,第一个关键字的值为图像的二进制值,第二个为图像的格式。
2、使用Slim读取TFrecord文件的步骤
1、设置解码器,一般设置为decoder=slim.tfexample_decoder.TFExampleDecoder(),同时要指定其keys_to_features,和items_to_handlers两个字典参数。key_to_features这个字典需要和TFrecord文件中定义的字典项匹配。items_to_handlers中的关键字可以是任意值,但是它的handler的初始化参数必须要来自于keys_to_features中的关键字。
2、定义数据集类,一般为dataset=slim.dataset.Dataset():它把datasource、reader、decoder、num_samples等参数封装好。
3、定义数据集的数据提供者类,一般为provider=slim.dataset_data_provider.DatasetDataProvider(),需要传入的参数:dataset, num_readers, reader_kwargs, shuffle, num_epochs,common_queue_capacity,common_queue_min, record_key=',seed, scope等。在这个类中:
(1)首先调用_,data=parallel_reader.parallel_read(),这个方法调用tf.train.string_input_producer()得到TFrecord的文件队列(filename_queue),然后根据是否shuffle生成一个公共队列(common queue),用reader_class,common_queue,num_readers,reader_kwargs=reader_kwargs等参数初始化ParallelReader(),然后调用它的read(filename_queuq)方法,这个read()方法先用reader从filename_queue中读取数据然后enqueue到common queue中,然后从common queue中dequeue,从而得到(filename,data)的键值对。
(2)调用items=dataset.decoder.list_items()得到decoder中的items_to_handlers的关键字列表items。
(3)根据1)和2)得到的data和items,调用tensors=dataset.decoder.decode(data, items)。这解码过程中,首先调用example=parsing_ops.parse_single_example(data,keys_to_features)来解析序列化数据得到一个字典特征,然后根据items_to_handlers中传给handler的那些items(这些items来自keys_to_features中的keys),将example中的字典中属于某个handler的多个键值对(因为一个handler用多个items初始化,所以一个handler对应example中多个键值对)交给相应的handler处理,然后每个handler处理完成后返回一个tensor,将所有tensor组成一个列表tensors。
(4)然后将2)中得到的items和3)中得到的tensors进行匹配生成一个字典items_to_tensors。
4、调用provider的get方法从items_to_tensors中获取响应的items对应的tensor,比如[image, label] = provider.get(['image', 'label'])
3、实例
这里我的图片放在D:/test/目录下,有0-9共10张图片。
#coding=utf-8
import tensorflow as tf
import numpy as np
import os
from PIL import Image
slim = tf.contrib.slim
# 创建TFrecord文件
def create_record_file():
train_filename = "train.tfrecords"
if os.path.exists(train_filename):
os.remove(train_filename)
# 创建.tfrecord文件,准备写入
writer = tf.python_io.TFRecordWriter('./'+train_filename)
with tf.Session() as sess:
for i in range(10):
img_raw = tf.gfile.FastGFile("D:/test/"+str(i)+".jpg", 'rb').read()
decode_data = tf.image.decode_jpeg(img_raw)
image_shape= decode_data.eval().shape
example = tf.train.Example(features=tf.train.Features(
feature={
'image/encoded':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'image/format':tf.train.Feature(bytes_list = tf.train.BytesList(value=[b'jpg'])),
'image/width':tf.train.Feature(int64_list = tf.train.Int64List(value=[image_shape[1]])),
'image/height':tf.train.Feature(int64_list = tf.train.Int64List(value=[image_shape[0]])),
'image/label':tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),
}))
writer.write(example.SerializeToString()) # 序列化保存
writer.close()
print ("保存tfrecord文件成功。")
# 使用Slim的方法从TFrecord文件中读取
def read_record_file():
tfrecords_filename = "train.tfrecords"
# 将tf.train.Example反序列化成存储之前的格式。由tf完成
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
'image/label': tf.FixedLenFeature((), tf.int64, default_value=0),
}
# 将反序列化的数据组装成更高级的格式。由slim完成
items_to_handlers = {
'image': slim.tfexample_decoder.Image(image_key='image/encoded',
format_key='image/format',
channels=3),
'label': slim.tfexample_decoder.Tensor('image/label'),
'height': slim.tfexample_decoder.Tensor('image/height'),
'width': slim.tfexample_decoder.Tensor('image/width')
}
# 定义解码器,进行解码
decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
# 定义dataset,该对象定义了数据集的文件位置,解码方式等元信息
dataset = slim.dataset.Dataset(
data_sources=tfrecords_filename,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=10, # 训练数据的总数
items_to_descriptions=None,
num_classes=10,
)
#使用provider对象根据dataset信息读取数据
provider = slim.dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=1,
common_queue_capacity=20,
common_queue_min=1)
# 获取数据
[image, label,height,width] = provider.get(['image', 'label','height','width'])
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(10):
img,l,h,w= sess.run([image,label,height,width])
img = tf.reshape(img, [h,w,3])
print (img.shape)
img=Image.fromarray(img.eval(), 'RGB') # 这里将narray转为Image类,Image转narray:a=np.array(img)
img.save('./'+str(l)+'.jpg') # 保存图片
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
create_record_file()
read_record_file()