Without further ado, let's get right to it.
What is TFRecord?
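TFRecord is a binary file format defined by TensorFlow. It stores training data in a form the input pipeline can read back efficiently, which makes it convenient for feeding data to a model.

Anyone who started out with TensorFlow has probably run its bundled MNIST handwritten-digit example. That example downloads the MNIST dataset as ready-made, preprocessed binary files. We can use them directly without loading or preprocessing any images ourselves. In other words, the MNIST example skips the raw data entirely and lets us focus only on training the network.

Real projects rarely come with such ready-made data files, and this is where TFRecord helps. It lets us convert raw images into a single binary file. In that file each image is stored as bytes, together with fields such as its label. Whatever the storage format, an image has to end up as a numeric tensor before a model can use it; you can think of that tensor as a matrix of pixel values. TFRecord stores this data in a binary form that TensorFlow can load without re-reading and decoding thousands of separate image files.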
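On disk, a TFRecord file is simply a sequence of length-prefixed records. Each record is protected by masked CRC-32C checksums, and the payloads are usually serialized tf.train.Example messages. The standard-library sketch below writes and reads that framing. It illustrates that nothing in the container format is image-specific. It follows the record layout described in TensorFlow's documentation (uint64 length, masked CRC of the length, payload, masked CRC of the payload). It is not a replacement for tf.python_io.TFRecordWriter. The helper names (crc32c, masked_crc, write_records, read_records) are hypothetical.

```python
import io
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0x82F63B78
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores a rotated-and-offset ("masked") CRC-32C."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(stream, records):
    """Frame each payload as: length, CRC(length), payload, CRC(payload)."""
    for payload in records:
        length = struct.pack("<Q", len(payload))  # uint64, little-endian
        stream.write(length)
        stream.write(struct.pack("<I", masked_crc(length)))
        stream.write(payload)
        stream.write(struct.pack("<I", masked_crc(payload)))


def read_records(stream):
    """Yield the payloads back, verifying both checksums of every record."""
    while True:
        length = stream.read(8)
        if not length:
            return  # clean end of file
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(length):
            raise ValueError("corrupted length field")
        (n,) = struct.unpack("<Q", length)
        payload = stream.read(n)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupted record payload")
        yield payload


buf = io.BytesIO()
write_records(buf, [b"first record", b"second record"])
buf.seek(0)
print(list(read_records(buf)))  # [b'first record', b'second record']
```

Each record costs 16 bytes of framing on top of its payload. The checksums let a reader detect a truncated or corrupted file instead of silently returning garbage.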
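TFRecord data structure:
A TFRecord file is a sequence of records, and each record is usually a serialized Example message. An Example wraps a Features message, which maps feature names to Feature values. Each Feature holds exactly one of a BytesList, a FloatList or an Int64List. The protocol buffer definitions look like this: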
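```protobuf
message Example {
  Features features = 1;
};

message Features {
  // Map from feature name to feature value.
  map<string, Feature> feature = 1;
};

message Feature {
  // Each feature holds exactly one kind of list.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
```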
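These are ordinary Protocol Buffers messages, so a serialized Example is nothing more than nested length-delimited fields. To make the nesting visible, the standard-library sketch below hand-encodes an Example whose features are int64 or bytes lists. It follows the protobuf wire format: field tags, base-128 varints, and packed repeated integers. It is for illustration only. In practice tf.train.Example(...).SerializeToString(), as used in the code below, does this for you. The function names are hypothetical, and negative integers and float_list are deliberately not handled.

```python
def varint(n: int) -> bytes:
    """Encode a non-negative integer as a protobuf base-128 varint."""
    out = bytearray()
    while True:
        low7 = n & 0x7F
        n >>= 7
        if n:
            out.append(low7 | 0x80)  # more bytes follow
        else:
            out.append(low7)
            return bytes(out)


def length_delimited(field_number: int, payload: bytes) -> bytes:
    """Wire type 2: tag, byte length, then the payload itself."""
    return varint((field_number << 3) | 2) + varint(len(payload)) + payload


def encode_feature(values) -> bytes:
    """Feature.kind: bytes_list = 1 or int64_list = 3 (float_list omitted here)."""
    if all(isinstance(v, bytes) for v in values):
        # BytesList { repeated bytes value = 1; }
        inner = b"".join(length_delimited(1, v) for v in values)
        return length_delimited(1, inner)
    # Int64List { repeated int64 value = 1 [packed = true]; }, non-negative values only
    packed = b"".join(varint(v) for v in values)
    return length_delimited(3, length_delimited(1, packed))


def encode_example(features: dict) -> bytes:
    """Example { Features features = 1; }, Features { map<string, Feature> feature = 1; }."""
    entries = b""
    for key, values in features.items():
        # Each map entry is itself a message with key = 1 and value = 2.
        entry = length_delimited(1, key.encode("utf8")) + length_delimited(2, encode_feature(values))
        entries += length_delimited(1, entry)
    return length_delimited(1, entries)


print(encode_example({"a": [1]}).hex())  # 0a0c0a0a0a016112051a030a0101
```

Reading the output left to right: Example.features (0a 0c), one map entry (0a 0a), key "a" (0a 01 61), value (12 05), int64_list (1a 03), packed value 1 (0a 01 01).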
How to use TFRecord:
Generating a TFRecord file from raw images:
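Image data layout: the save function below walks ./dataset and treats every leaf directory as one class, using the directory name as the label. The images are therefore expected under ./dataset/<class_name>/.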
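To make that layout rule concrete, here is a standard-library sketch. It builds a throwaway tree with hypothetical cat and dog folders and collects (label, path) pairs using the same os.walk leaf-directory rule as the save function. The function name list_labeled_files and the folder and file names are illustrative only.

```python
import os
import tempfile


def list_labeled_files(dataset_path):
    """Return sorted (label, path) pairs; every leaf directory is one class."""
    pairs = []
    for root, dirs, files in os.walk(dataset_path):
        if len(dirs) == 0:  # leaf directory -> one class, named after the folder
            cls_name = os.path.split(root)[-1]
            for file_name in files:
                pairs.append((cls_name, os.path.join(root, file_name)))
    return sorted(pairs)


with tempfile.TemporaryDirectory() as tmp:
    dataset = os.path.join(tmp, "dataset")
    # Hypothetical layout: dataset/cat/*.jpg and dataset/dog/*.jpg (empty placeholder files)
    for cls, name in [("cat", "cat.1.jpg"), ("cat", "cat.2.jpg"), ("dog", "dog.1.jpg")]:
        os.makedirs(os.path.join(dataset, cls), exist_ok=True)
        open(os.path.join(dataset, cls, name), "wb").close()
    labels = [label for label, _ in list_labeled_files(dataset)]
    print(labels)  # ['cat', 'cat', 'dog']
```

A consequence of this rule is that files placed directly in ./dataset are ignored as long as it has subfolders. A nested folder such as dataset/cat/kitten/ would become its own class, named "kitten".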
Code to generate the TFRecord file:
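```python
import os

import cv2
import tensorflow as tf  # TensorFlow 1.x API (tf.python_io)


def save_to_tfrecord():
    def _int64_feature(value):
        # Wrap a single integer in an Int64List feature.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        # Wrap a single bytes object in a BytesList feature.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    dataset_path = os.path.join(".", "dataset")
    tfrecord_dir = os.path.join(".", "tfrecords_files")
    os.makedirs(tfrecord_dir, exist_ok=True)  # the writer does not create missing directories
    tfrecord_file_name = os.path.join(tfrecord_dir, "cat_and_dog_datasets.tfrecord")
    tfrecord_writer = tf.python_io.TFRecordWriter(path=tfrecord_file_name)
    # Each leaf directory under ./dataset is one class; its name becomes the label.
    for root, dirs, files in os.walk(dataset_path):
        if len(dirs) == 0:
            cls_name = os.path.split(root)[-1]
            for image_name in files:
                image_path = os.path.join(root, image_name)
                # IMREAD_COLOR always yields 3 channels (BGR), even for grayscale files,
                # so every record can later be reshaped to (224, 224, 3).
                img = cv2.imread(image_path, cv2.IMREAD_COLOR)
                if img is None:  # skip files OpenCV cannot decode
                    continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (224, 224))
                img_pixels = img.shape[0] * img.shape[1]  # 224 * 224 = 50176
                img_raw = img.tobytes()  # raw uint8 pixels; tostring() is deprecated in NumPy
                img_example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            "pixels": _int64_feature(img_pixels),
                            "labels": _bytes_feature(cls_name.encode("utf8")),
                            "image_raw": _bytes_feature(img_raw)
                        }
                    )
                )
                tfrecord_writer.write(img_example.SerializeToString())
    tfrecord_writer.close()
```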
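The image_raw feature is just the pixel buffer, uint8, in row-major (C) order with shape height × width × channels. The loader relies on this when it recovers the image with decode_raw followed by reshape((224, 224, 3)). It is also why every stored image must have the same size and channel count. The standard-library sketch below mimics what img.tobytes() produces for a tiny hypothetical 2×2 RGB image and computes where a given pixel lands. flatten_hwc and pixel_offset are illustrative names, not OpenCV or NumPy functions.

```python
def flatten_hwc(image):
    """Row-major flattening of a nested [row][col][channel] uint8 image, like ndarray.tobytes()."""
    return bytes(value for row in image for pixel in row for value in pixel)


def pixel_offset(y, x, c, width, channels=3):
    """Byte index of channel c of pixel (y, x) in the flat buffer."""
    return (y * width + x) * channels + c


# Hypothetical 2x2 RGB image: red, green on the top row; blue, white on the bottom row.
image = [
    [(255, 0, 0), (0, 255, 0)],
    [(0, 0, 255), (255, 255, 255)],
]
raw = flatten_hwc(image)
print(len(raw))                              # 12 == 2 * 2 * 3
print(raw[pixel_offset(1, 0, 2, width=2)])   # 255, the blue channel of the bottom-left pixel

# For the 224x224 RGB images in this article:
print(224 * 224)      # 50176, the value stored in the "pixels" feature
print(224 * 224 * 3)  # 150528, the number of bytes in each "image_raw" feature
```

Storing raw pixels keeps loading simple, but each image takes about 147 KB in the file. That is usually much more than the original JPEG.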
Loading the TFRecord file (TensorFlow 1.x queue-based input pipeline):
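```python
import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x API (queue runners, tf.Session)


def load_from_tfrecord():
    tfrecord_reader = tf.TFRecordReader()
    tfrecord_file_name = ["tfrecords_files/cat_and_dog_datasets.tfrecord"]
    # Queue of file names to read. num_epochs is tracked in a local variable,
    # which is why tf.local_variables_initializer() must run below.
    tfrecord_file_queue = tf.train.string_input_producer(
        string_tensor=tfrecord_file_name,
        shuffle=True,
        num_epochs=200
    )
    _, serialized_example = tfrecord_reader.read(queue=tfrecord_file_queue)
    # Keys and dtypes must match the ones used when writing the file.
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={
            "pixels": tf.FixedLenFeature(shape=[], dtype=tf.int64),
            "labels": tf.FixedLenFeature(shape=[], dtype=tf.string),
            "image_raw": tf.FixedLenFeature(shape=[], dtype=tf.string)
        }
    )
    pixels = features["pixels"]  # already int64, no cast needed
    labels = features["labels"]  # already a string tensor
    # Reinterpret the raw bytes as a flat uint8 vector of 224 * 224 * 3 values.
    img_raw = tf.decode_raw(features["image_raw"], tf.uint8)
    init_op = [tf.local_variables_initializer(), tf.global_variables_initializer()]
    with tf.Session() as sess:
        sess.run(init_op)
        # These two lines are essential: they start the background threads that
        # fill the queues. Without them, sess.run() below blocks forever.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(4):
                print("##### sample %d" % i)
                img_pixels, img_labels, img = sess.run([pixels, labels, img_raw])
                print("img pixels -> %s" % img_pixels)
                print("img labels -> %s" % img_labels.decode("utf8"))
                print("img matrix -> %s" % (img.shape,))
                plt.imshow(img.reshape((224, 224, 3)))
                plt.show()
        finally:
            # Ask the queue threads to stop, then wait for them to exit.
            coord.request_stop()
            coord.join(threads)
```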
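Why are the Coordinator and start_queue_runners lines essential? In the TF 1.x pipeline, string_input_producer and the reader only define queues. Background threads have to fill those queues, and reading from an empty queue simply blocks. The sketch below imitates that arrangement with Python's threading and queue modules. A producer thread enqueues file names for num_epochs epochs into a bounded queue, shuffling each epoch. A threading.Event plays the role of the Coordinator's stop request (coord.request_stop()). It is an analogy for how the pipeline behaves, not TensorFlow's implementation. It covers only the file-name queue, and the names _put, file_name_producer and take are hypothetical.

```python
import queue
import random
import threading


def _put(q, item, stop_event):
    """Put with a timeout loop so the thread can notice a stop request."""
    while not stop_event.is_set():
        try:
            q.put(item, timeout=0.1)  # waits while the queue is full
            return True
        except queue.Full:
            pass
    return False


def file_name_producer(file_names, num_epochs, q, stop_event, shuffle=True):
    """Roughly what string_input_producer's queue runner does: enqueue names epoch by epoch."""
    for _ in range(num_epochs):
        names = list(file_names)
        if shuffle:
            random.shuffle(names)  # with a single file, shuffling changes nothing
        for name in names:
            if not _put(q, name, stop_event):
                return
    _put(q, None, stop_event)  # end-of-input marker (TF raises OutOfRangeError instead)


def take(n, file_names, num_epochs):
    q = queue.Queue(maxsize=4)       # bounded, like a queue with a fixed capacity
    stop_event = threading.Event()   # stands in for tf.train.Coordinator
    worker = threading.Thread(
        target=file_name_producer, args=(file_names, num_epochs, q, stop_event), daemon=True
    )
    worker.start()                   # the analogue of start_queue_runners()
    results = []
    try:
        while len(results) < n:
            item = q.get()           # would block forever if no thread had been started
            if item is None:
                break
            results.append(item)
    finally:
        stop_event.set()             # like coord.request_stop()
        worker.join()                # like coord.join(threads)
    return results


print(take(4, ["cat_and_dog_datasets.tfrecord"], num_epochs=200))
```

With a single file and num_epochs=200, the queue yields the same file name up to 200 times. shuffle=True therefore only changes the order of files, not the order of the examples inside the file.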
Reference: https://blog.csdn.net/zxyhhjs2017/article/details/82556746