实例5:将图片文件制作成TFRecord数据集
有两个文件夹,放置男人和女人的照片
要求
- 将两个文件夹中的图片制成TFRecord格式的数据集
- 从数据集中读取数据,并将得到的图片数据保存到本地文件
TFRecord格式与TensorFlow框架绑定,通用性较差
但它是一种非常高效的数据持久化方法,尤其对需要预处理的样本集
将处理后的数据用TFRecord格式保存训练,可以大大提高训练模型的运行效率
1. 样本介绍
- 文件夹的名称可以当做样本标签
- 文件夹中的具体图片文件可被当做具体的样本数据
2. 代码实现:读取样本文件的目录及标签
import os
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
from PIL import Image
from sklearn.utils import shuffle
import numpy as np
from tqdm import tqdm
def load_sample(sample_dir, shuffleflag = True):
print("loading sample dataset..")
lfilenames = []
labelsnames = []
for (dirpath, dirnames, filenames) in os.walk(sample_dir):
for filename in filenames:
filename_path = os.sep.join([dirpath, filename])
lfilenames.append(filename_path)
labelsnames.append(dirpath.split('\\')[-1])
lab = list(sorted(set(labelsnames)))
labdict = dict(zip(lab, list(range(len(lab)))))
labels = [labdict[i] for i in labelsnames]
if shuffleflag == True:
return (shuffle(np.asarray(lfilenames), np.asarray(labels))), np.asarray(lab)
return (np.asarray(lfilenames), np.asarray(labels)), np.asarray(lab)
directory = 'man_woman\\'
(filenames, labels), _ = load_sample(directory, shuffleflag=False)
引入第三方库tqdm,以便在批处理过程中显示进度
3. 代码实现:定义函数生成TFRecord
def makeTFRec(filenames,labels):
writer = tf.python_io.TFRecorWriter("mydata.tfrecords")
for i in tqdm(range(0,len(labels))):
img = Image.open(filenames[i])
img = img.resize((256,256))
img_raw = img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
"label":
tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]]),
'img_raw':
tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
makeTFRec(filenames, labels)
TFRecordWriter类
tf.train.Feature()特征
tf.trian.Features()和tf.train.Example()
4. 代码实现:读取TFRecord数据集,并将其转化为队列
函数read_and_decode支持两种模式的队列格式转换:
- 训练模式:对数据集进行乱序操作,并将其按照指定批次组合起来
- 测试模式:按照顺序读取数据集一次
def read_and_decode(filenames, flag='train', batch_size=3):
if flag == 'train':
filename_queue = tf.train.string_input_producer(filenames) #已经进行乱序读取了
else:
filename_queue = tf.trina.string_input_producer(filenames, nun_epochs=1, shuffle=False) #取一个批次,并且顺序
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features = {
'label': tf.FixedLenFeature([],tf.int64),
'img_raw': tf.FixEdLenFeature([],tf.string),
})
#tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [256,256,3])
label = tf.cast(features['label'],tf.int32)
if flag == 'train':
image = tf.cast(image,tf.float32) * (1. / 255) - 0.5 #训练是将其归一化
img_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, capacity=20)
return img_batch, label_batch
return image, label
TFRecordfilenames = ["mydata.tfrecord"]
image, label = read_and_decode(TFRecordfilenames, flag='test')
5. 代码实现:建立会话,将数据保存到文件
saveimgpath = 'show\\'
if tf.gfile.Exists(saveimgpath): #如果存在saveimgpath,则将其删除
tf.gfile.DeleteRecursively(saveimgpath)
tf.gfile.MakeDirs(saveimgpath)
with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
myset = set([]) #建立集合对象,用于存放子文件夹
try:
print("start")
i = 0
while True:
example, examplelab = sess.run([image, label])
print("2")
examplelab = str(examplelab)
if examplelab not in myset:
myset.add(examplelab)
tf.gfile.MakeDirs(saveimgpath+examplelab)
img = Image.fromarray(example, 'RGB') #转换成image格式
img.save(saveimgpath+examplelab + '/' + str(i) + '_Label' + '.jpg') #保存图片
print(i)
i = i + 1
except tf.errors.OutOfRangeError:
print("Done Test -- epoch limit reached")
finally:
coord.request_stop()
print("stop()")
coord.join(threads)
print("stop()")
sess.close()
‘utf-8’ codec can’t decode byte 0xd5 in position 105: invalid continuation byte错误