利用文件夹分类图像,来把我们所需要用到的图像数据集转成相应的TFRecord格式文件,以便我们后续的使用。
文件夹的名字为标签名字,相应的文件夹里面存放此类标签的数据图像。
文件夹格式如下:
"0"-文件夹:
- img1.jpg
- img2.jpg
- img3.jpg
- ·····
"1"-文件夹:
- img1.jpg
- img2.jpg
- ·····
"2"-文件夹:
- img1.jpg
- img2.jpg
- ·····
(以此类推)
这里直接上代码(完整代码在最下面),代码的复用就只需更改前面的这块内容就可以了。
# 测试集数量
_NUM_TEST = 2000
# 随机种子
_RANDOM_SEED = 0
# 定义数据块数量
_NUM_SHARDS = 5
# 数据集路径
DATASET_DIR = "./data/Fnt"
# 标签文件存放名字
LABEL_FILENAME = "labels.txt"
变量 | 解释 |
---|---|
_NUM_TEST | 测试集数量,在整个数据集中随机抽取_NUM_TEST个数据充当测试集 |
_RANDOM_SEED | 随机种子,上面的随机用到,改不改都可以 |
_NUM_SHARDS | 碎片化数量,即最后会生成_NUM_SHARDS个TFRecord文件 |
_DATASET_DIR | 整个数据集文件夹存放的位置。我的数据是Fnt/0/img1.jpg,所以取到Fnt文件夹就可以了 |
LABEL_FILENAME | 最后会生成对应的label标签(0:a,1:b,2:c......) |
改了自己对应的数据集目录和对应的变量后,就可以运行了,保存的位置跟数据集的目录相同。
完整代码:
import tensorflow as tf
import os
import random
import sys
# 测试集数量
_NUM_TEST = 2000
# 随机种子
_RANDOM_SEED = 0
# 定义数据块数量
_NUM_SHARDS = 5
# 数据集路径
DATASET_DIR = "./data/Fnt"
# 标签文件存放名字
LABEL_FILENAME = "labels.txt"
# 定义tfrecord文件的路径+文字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'image-%s.tfrecords-%05d-of-%05d' % (split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename)
# 判断tfrecord 文件是否存在
def _dataset_exists(dataset_dir):
for split_name in ['train', 'test']:
for shard_id in range(_NUM_SHARDS):
# 定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True
# 获取所有文件以及分类
def _get_filenames_and_classes(dataset_dir):
# 数据目录
directories = []
# 分类名称
class_names = []
for filename in os.listdir(dataset_dir):
# 合并文件路径
path = os.path.join(dataset_dir, filename)
# 判断该路径是否为目录
if os.path.isdir(path):
# 加入数据目录
directories.append(path)
class_names.append(filename)
photo_filenames = []
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename)
photo_filenames.append(path)
return photo_filenames, class_names
def int64_feature(values):
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def image_to_tfexample(image_data, class_id):
return tf.train.Example(features=tf.train.Features(feature={
'image': bytes_feature(image_data),
'label': int64_feature(class_id)
}))
def write_label_file(labels_to_class_names, dataset_dir, filename=LABEL_FILENAME):
labels_filename = os.path.join(dataset_dir, filename)
with tf.gfile.Open(labels_filename, 'w') as f:
for label in labels_to_class_names:
class_name = labels_to_class_names[label]
f.write("%d:%s\n" % (label, class_name))
def _convert_dataset(split_name, file_names, class_names_to_ids, dataset_dir):
assert split_name in ['train', 'test']
# 计算每个数据块有多少数据
num_per_shard = int(len(file_names)/_NUM_SHARDS)
with tf.Graph().as_default():
with tf.Session():
for shard_id in range(_NUM_SHARDS):
# 定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
# 每一个数据块的开始的位置
start_ndx = shard_id * num_per_shard
# 每一个数据块的最后的位置
end_ndx = min((shard_id+1)*num_per_shard, len(file_names))
for i in range(start_ndx, end_ndx):
try:
sys.stdout.write("\r>>[%s] Converting image %d/%d shard %d"
% (split_name, i+1, len(file_names), shard_id))
sys.stdout.flush()
# 读取图片
image_data = tf.gfile.FastGFile(file_names[i], 'rb').read()
# 获得图片类别名称
class_name = os.path.basename(os.path.dirname(file_names[i]))
# 找到类别名称对应的id
class_id = class_names_to_ids[class_name]
example = image_to_tfexample(image_data, class_id)
tfrecord_writer.write(example.SerializeToString())
except IOError as e:
print("Could not read:", file_names[i])
print("Error", e)
print("Skip~\n")
sys.stdout.write('\n')
sys.stdout.flush()
if __name__ == '__main__':
if _dataset_exists(DATASET_DIR):
print("tfrecord已存在")
else:
# 获取图片和分类
photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
# 把分类转为字典格式,类似于('a':0,'b':1,'c':2....)
class_name_to_ids = dict(zip(class_names, range(len(class_names))))
# 把数据切分为训练集和测试集
random.seed(_RANDOM_SEED)
random.shuffle(photo_filenames) # 打乱数组
traning_filenames = photo_filenames[_NUM_TEST:] # 500~ 为训练集
testing_filenames = photo_filenames[:_NUM_TEST] # 0~500 为测试集
# 数据转换
_convert_dataset('train', traning_filenames, class_name_to_ids, DATASET_DIR)
_convert_dataset('test', testing_filenames, class_name_to_ids, DATASET_DIR)
# 输出labels
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
write_label_file(labels_to_class_names, DATASET_DIR)
运行结果:
TFRecord文件(5个数量块):
(End)