一 TFRecords文件
TFRecords文件是TensorFlow专门的文件存储及读取格式,其中包含了tf.train.Example 协议内存块(protocol buffer),存储特征值与数据内容。通过tf.python_io.TFRecordWrite类,可以获取相应的数据并将其填入到Example协议内存块中,最终生成TFRecords文件。简单地说,tf.train.Example有若干数据特征(Features),而Features又有若干Feature字典,其中Feature只接受FloatList, ByteList, Int64List三种数据格式。TFRecords文件就是通过一个包含着二进制文件的数据文件,将特征与标签进行保存以便于Tensorflow读取。
二 案例分析
本例实现对daisy,dandelion,rose进行分类,项目结构如下:
其中,Data文件夹下有daisy,dandelion,rose三类植物,每类四张JPG格式图片,TFRecords_Writer.py负责创建TFRecords文件,TFRecords_Reader.py负责读取TFRecords文件。
1. TFRecords_Writer.py
import os
import tensorflow as tf
from PIL import Image
path = "Data"
dirnames = os.listdir(path)
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for name in dirnames:
class_path = path + os.sep + name
for img_name in os.listdir(class_path):
img_path = class_path + os.sep + img_name
img = Image.open(img_path)
img = img.resize((500, 500))
img_raw = img.tobytes() # 将图片转化成二进制形式
example = tf.train.Example(
features=(tf.train.Features(
feature={
'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name.encode()])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}
))
)
writer.write(example.SerializeToString())
2. TFRecords_Reader.py
import tensorflow as tf
import cv2
def read_and_decode(filename):
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.string),
'image': tf.FixedLenFeature([], tf.string),
})
img = tf.decode_raw(features['image'], tf.uint8)
img = tf.reshape(img, [500, 500, 3])
img = tf.cast(img, tf.float32) * (1. / 128) - 0.5
label = tf.cast(features['label'], tf.string)
return img, label
filename = "train.tfrecords"
img, label = read_and_decode(filename)
img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=1, capacity=10, min_after_dequeue=1)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
threads = tf.train.start_queue_runners(sess=sess)
for _ in range(10):
# 同时取图片及标签,否则图片与标签无法对应
val, label = sess.run([img_batch, label_batch])
val.resize((500, 500, 3))
cv2.imshow("cool", val)
cv2.waitKey()
print(label)
注:读取数据的格式必须与写入TFRecords文件的数据格式一致。