下载数据cifar-10数据集(形式为分文件夹存放原始图片),文件目录结构如下所示:
我们只选择演示其中的train文件夹中的图片。
生成TFRecord文件:
# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import glob
import tensorflow as tf
# 指定图片数据路径
PATH = 'cifar-10'
# 指定输出TFRecord文件
OUT_DIR = 'tfrecord_file'
# 以下函数将一个value转换为与tf适配的格式
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
if isinstance(value, type(tf.constant(0))):
value = value.numpy()
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
# 利用一张图片创建一个example
def make_example(image_string, label):
# 如果我们想获取图片的高宽和深度,则将读取到的数据先解码为jpeg图片格式,然后用shape取值
# image_shape = tf.image.decode_jpeg(image_string).shape
feature = {
# 'height': _int64_feature(image_shape[0]),
# 'width': _int64_feature(image_shape[1]),
# 'depth': _int64_feature(image_shape[2]),
'label': _int64_feature(label),
'image_raw': _bytes_feature(image_string)
}
exam = tf.train.Example(features=tf.train.Features(feature=feature))
return exam
# 定义一个TFRecordWriter实例
write = tf.io.TFRecordWriter(OUT_DIR)
# 创建一个index:img_name的字典
index_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
cate_list = ['airplane', 'mobilephone', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img_dict = zip(index_list, cate_list)
# 循环所有类别的图片文件夹,生成tfrecord文件
for index, img_cate in img_dict:
# 生成训练集每个图片的路径
train_img_list = glob.glob(os.path.join(PATH, 'train/{}/*.jpg'.format(index)))
# test_img_list = glob.glob(os.path.join(PATH, 'test/{}/*.jpg'.format(index)))
print("start to make train file index = {}:".format(index))
count = 0
# 开始循环读取每张图片,然后连同label,一起打包到TFRecord文件中
for per_img_dir in train_img_list:
# 使用gfile.Gfile来读图片文件
with tf.io.gfile.GFile(per_img_dir, 'rb') as per_img_fp:
img_data = per_img_fp.read()
# 为每一张图片+label创建一个example
example = make_example(img_data, index)
# 打印看看example具体内容结构
if count == 0:
for line in str(example).split('\n')[:15]:
print(line)
# 将每个example写入到write中
write.write(example.SerializeToString())
count += 1
if count % 1000 == 0:
print(count)
write.close()
其中每一个图片所生成的example结构如下:
features {
feature {
key: "image_raw"
value {
bytes_list {
value: "\377\330\377\340\000\020JFIF\000\001\001\000\000\001\000\001\000\000\377\333\000C\000\010\006\006\007\006\005\010\007\007\007\t\t\010\n\014\024\r\014\013\013\014\032
......
2)g\230\266\025Q\to\312\276\252\360\037\206\233B\360\355\231\276\211N\252\321\177\244JynNB\223\355\300\374)F-=\004\225\217\377\331"
}
}
}
feature {
key: "label"
value {
int64_list {
value: 0
}
}
}
其中包含我们再feature中指定的label和image_raw。
读取TFRecord中的数据: