TFRecord数据格式是TensorFlow官方推荐的数据格式,不仅规范化读写,而且提高了IO效率。
1.制作TFRecord数据
原始数据为下图所示,文件夹名为类别标号,文件夹中存放的是各个类的图片:
制作TFRecord的代码为:
import os
import tensorflow as tf
from PIL import Image
import numpy as np
def create_record(inpath,outpath,size): # inpath:原始数据路径 outpath:TFRecord文件输出路径 size:统一图片尺寸,如[256,256]
writer = tf.python_io.TFRecordWriter(outpath)
dirs = os.listdir(inpath)
for index, name in enumerate(dirs):
class_path = inpath +"/"+ name+"/"
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((size[1],size[0]),Image.ANTIALIAS) #将图片按照指定的size进行缩放,统一图片大小,注意resize方法的参数将长度放在前,宽度放在后
img_raw = img.tobytes() #将图片转化为原生bytes
example = tf.train.Example(
features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
2.读取TFRecord文件
读取的代码为:
def read_and_decode(filename,size): # filename:TFRecord文件路径 size:图片尺寸,如[256,256,1],最后一位表示颜色通道,1表示为灰度图像
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, size)
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 # 图像数据
label = tf.cast(features['label'], tf.int32) # 标签
return img, label
3.分批读取
做CNN训练时,通常将输入数据整理成batch分批训练,而不是一次性全部输入,下面这个方法对读取文件进一步进行了封装,能够分批读取数据。
def get_train_batch(filename,size,min_after_dequeue,batch_size,capacity):
"""
# 获取批量的训练数据
# filename TFRecord文件路径
# size 图像大小
# min_after_dequeue 队列中最小样本数,这个数字越大表示一个batch的数据越乱
# batch_size 一个batch的大小
# capacity 队列的大小,应该大于batch_size
"""
image,label = read_and_decode(filename,size)
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
return image_batch,label_batch
下面为不使用shuffle机制的批次获取方法,获取测试集的时候使用。
def get_test_batch(filename, size, batch_size, capacity):
image, label = read_and_decode(filename, size)
image_batch, label_batch = tf.train.batch(
[image, label], batch_size=batch_size, capacity=capacity)
return image_batch, label_batch
4.读取TFRecord文件并保持为图片
def read_and_save(inpath,outpath,size,num): # inpath:TFRecord文件路径 outpath:输出的图片路径 size:图片大小,如[256,256] num:图片个数
filename_queue = tf.train.string_input_producer([inpath])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, size)
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess:
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(num):
example, l = sess.run([image,label])
img=Image.fromarray(example)
img.save(outpath+'/'+str(i)+'_Label_'+str(l)+'.png') # 保存图片
coord.request_stop()
coord.join(threads)