1首先准备好自己图片数据集,我是用cifar10 的其中五个类别,分别是bird,car ,cat,deer,plane。五个类别数据分开放置。例如:
然后就是根据数据集生成tfrecord,生成的是protobuf(二进制文件,加速文件传输和处理速度),代码如下
<
import tensorflow as tf
import os
import random
import math
import sys
#验证集数量
_NUM_TEST = 500
#随机种子
_RANDOM_SEED = 0
#数据块
_NUM_SHARDS = 5
#数据集路径
DATASET_DIR = "D:/Tensorflow/slim/images/"
#标签文件名字
LABELS_FILENAME = "D:/Tensorflow/slim/images/labels.txt"
#定义tfrecord文件的路径+名字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename)
#判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
for split_name in ['train', 'test']:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True
#获取所有文件以及分类
def _get_filenames_and_classes(d