引言
本次博文目的是记录下tfrecord数据集的制作与使用方式。(踩了无数坑OTZ)
这里贴上一个数据读取的官方教程:Tensorflow导入数据以及使用数据
接下来举个例子说明怎么用tfrecord,假设我要做个图片分类的任务。首先,我这里有一个txt文件,包含着所有图片的路径以及它们的标签。还有一个包含许多图片的文件夹。类似下图这样:
准备好了数据后,就可以制作与使用TFrecored啦~
制作TFrecord
当然是先写个制作TFrecord的函数啦。我们先读取图片信息的txt文件,得到每个图片的路径以及它们的标签,然后对这个图片作一些预处理,最后将图片以及它对应的标签序列化,并建立图片和标签的索引(即以下代码的”img_raw”, “label”)。详见代码。
import random
import tensorflow as tf
from PIL import Image
def create_record(records_path, data_path, img_txt):
# 声明一个TFRecordWriter
writer = tf.python_io.TFRecordWriter(records_path)
# 读取图片信息,并且将读入的图片顺序打乱
img_list = []
with open(img_txt, 'r') as fr:
img_list = fr.readlines()
random.shuffle(img_list)
cnt = 0
# 遍历每一张图片信息
for img_info in img_list:
# 图片相对路径
img_name = img_info.split(' ')[0]
# 图片类别
img_cls = int(img_info.split(' ')[1])
img_path = da