利用文件夹分类图像来生成TFRecord格式文件

利用文件夹分类图像,来把我们所需要用到的图像数据集转成相应的TFRecord格式文件,以便我们后续的使用。

文件夹的名字为标签名字,相应的文件夹里面存放此类标签的数据图像。

文件夹格式如下:

"0"-文件夹:

  • img1.jpg
  • img2.jpg
  • img3.jpg
  • ·····

"1"-文件夹:

  • img1.jpg
  • img2.jpg
  • ·····

"2"-文件夹:

  • img1.jpg
  • img2.jpg
  • ·····

(以此类推)

这里直接上代码(完整代码在最下面),代码的复用就只需更改前面的这块内容就可以了。

# 测试集数量
_NUM_TEST = 2000
# 随机种子
_RANDOM_SEED = 0
# 定义数据块数量
_NUM_SHARDS = 5
# 数据集路径
DATASET_DIR = "./data/Fnt"
# 标签文件存放名字
LABEL_FILENAME = "labels.txt"

 

变量解释
_NUM_TEST测试集数量,在整个数据集中随机抽取_NUM_TEST个数据充当测试集
_RANDOM_SEED随机种子,上面的随机用到,改不改都可以
_NUM_SHARDS碎片化数量,即最后会生成_NUM_SHARDS个TFRecord文件
_DATASET_DIR整个数据集文件夹存放的位置。我的数据是Fnt/0/img1.jpg,所以取到Fnt文件夹就可以了
LABEL_FILENAME最后会生成对应的label标签(0:a,1:b,2:c......)

改了自己对应的数据集目录和对应的变量后,就可以运行了,保存的位置跟数据集的目录相同。

 

完整代码:

import tensorflow as tf
import os
import random
import sys

# 测试集数量
_NUM_TEST = 2000
# 随机种子
_RANDOM_SEED = 0
# 定义数据块数量
_NUM_SHARDS = 5
# 数据集路径
DATASET_DIR = "./data/Fnt"
# 标签文件存放名字
LABEL_FILENAME = "labels.txt"


# 定义tfrecord文件的路径+文字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
    output_filename = 'image-%s.tfrecords-%05d-of-%05d' % (split_name, shard_id, _NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)


# 判断tfrecord 文件是否存在
def _dataset_exists(dataset_dir):
    for split_name in ['train', 'test']:
        for shard_id in range(_NUM_SHARDS):
            # 定义tfrecord文件的路径+名字
            output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
        if not tf.gfile.Exists(output_filename):
            return False
    return True


# 获取所有文件以及分类
def _get_filenames_and_classes(dataset_dir):
    # 数据目录
    directories = []
    # 分类名称
    class_names = []
    for filename in os.listdir(dataset_dir):
        # 合并文件路径
        path = os.path.join(dataset_dir, filename)
        # 判断该路径是否为目录
        if os.path.isdir(path):
            # 加入数据目录
            directories.append(path)
            class_names.append(filename)
    photo_filenames = []
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            photo_filenames.append(path)
    return photo_filenames, class_names


def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def image_to_tfexample(image_data, class_id):
    return tf.train.Example(features=tf.train.Features(feature={
        'image': bytes_feature(image_data),
        'label': int64_feature(class_id)
    }))


def write_label_file(labels_to_class_names, dataset_dir, filename=LABEL_FILENAME):
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.write("%d:%s\n" % (label, class_name))


def _convert_dataset(split_name, file_names, class_names_to_ids, dataset_dir):
    assert split_name in ['train', 'test']
    # 计算每个数据块有多少数据
    num_per_shard = int(len(file_names)/_NUM_SHARDS)
    with tf.Graph().as_default():
        with tf.Session():
            for shard_id in range(_NUM_SHARDS):
                # 定义tfrecord文件的路径+名字
                output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    # 每一个数据块的开始的位置
                    start_ndx = shard_id * num_per_shard
                    # 每一个数据块的最后的位置
                    end_ndx = min((shard_id+1)*num_per_shard, len(file_names))
                    for i in range(start_ndx, end_ndx):
                        try:
                            sys.stdout.write("\r>>[%s] Converting image %d/%d shard %d"
                                             % (split_name, i+1, len(file_names), shard_id))
                            sys.stdout.flush()
                            # 读取图片
                            image_data = tf.gfile.FastGFile(file_names[i], 'rb').read()
                            # 获得图片类别名称
                            class_name = os.path.basename(os.path.dirname(file_names[i]))
                            # 找到类别名称对应的id
                            class_id = class_names_to_ids[class_name]
                            example = image_to_tfexample(image_data, class_id)
                            tfrecord_writer.write(example.SerializeToString())
                        except IOError as e:
                            print("Could not read:", file_names[i])
                            print("Error", e)
                            print("Skip~\n")
    sys.stdout.write('\n')
    sys.stdout.flush()


if __name__ == '__main__':
    if _dataset_exists(DATASET_DIR):
        print("tfrecord已存在")
    else:
        # 获取图片和分类
        photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
        # 把分类转为字典格式,类似于('a':0,'b':1,'c':2....)
        class_name_to_ids = dict(zip(class_names, range(len(class_names))))

        # 把数据切分为训练集和测试集
        random.seed(_RANDOM_SEED)
        random.shuffle(photo_filenames)  # 打乱数组
        traning_filenames = photo_filenames[_NUM_TEST:]  # 500~ 为训练集
        testing_filenames = photo_filenames[:_NUM_TEST]  # 0~500 为测试集

        # 数据转换
        _convert_dataset('train', traning_filenames, class_name_to_ids, DATASET_DIR)
        _convert_dataset('test', testing_filenames, class_name_to_ids, DATASET_DIR)

        # 输出labels
        labels_to_class_names = dict(zip(range(len(class_names)), class_names))
        write_label_file(labels_to_class_names, DATASET_DIR)

运行结果:

TFRecord文件(5个数量块):

 (End)

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值