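This article uses the flower classification dataset as an example. The dataset is stored as one folder per class: each folder is named after its class and holds the images of that class.
1. Randomly split, read, and save the downloaded classification dataset
On Linux, enter the following commands: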
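To make the folder-per-class layout concrete, here is a minimal sketch that counts the files in each class folder. The helper name `count_images_per_class` is hypothetical and not part of the original code, and the layout `flower_photos/<class_name>/<image file>` is assumed from the description above.

```python
import os


def count_images_per_class(root_dir):
    """Map each class folder name under root_dir to the number of files in it."""
    counts = {}
    for entry in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, entry)
        # Only subfolders count as classes; loose files in root_dir are skipped
        if os.path.isdir(class_dir):
            counts[entry] = len(os.listdir(class_dir))
    return counts


# Example call (assumes the archive was extracted into ./flower_photos):
# print(count_images_per_class("flower_photos"))
```

A quick count like this is a cheap sanity check that the archive was extracted completely before any splitting is done.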
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz
On Windows, download the archive directly from the link below:
http://download.tensorflow.org/example_images/flower_photos.tgz
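As an alternative that works the same way on Linux and Windows, the download and extraction can also be done from Python with the standard library. This is only a sketch: the function name `download_and_extract` and its parameters are assumptions made for illustration, and only the URL comes from the text above.

```python
import os
import tarfile
import urllib.request

FLOWERS_URL = "http://download.tensorflow.org/example_images/flower_photos.tgz"


def download_and_extract(url=FLOWERS_URL, dest_dir="."):
    """Download a .tgz archive into dest_dir (if not already there) and unpack it."""
    os.makedirs(dest_dir, exist_ok=True)
    archive_path = os.path.join(dest_dir, os.path.basename(url))
    # Skip the download when the archive is already present
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)
    # Equivalent of `tar xzf`: unpack the gzip-compressed tarball into dest_dir
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest_dir)
    return archive_path


# Example call (performs a network download):
# download_and_extract()
```

Checking for an existing archive first avoids downloading the file again on repeated runs.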
import os
import random


def get_dataset_dict(imagedir, train_percentage=8):
    rootdir = imagedir
    # Names of the immediate subfolders of imagedir; each one is a class
    category = next(os.walk(imagedir))[1]
    dataset = {}
    label = {}
    for j, class_name in enumerate(category):
        subdir = os.path.join(rootdir, class_name)
        imagelist = os.listdir(subdir)
        number = len(imagelist)
        label[class_name] = j  # integer label for this class
        train_dataset = []
        test_dataset = []
        for i, image in enumerate(imagelist):
            # Random draw that decides which split this image goes to
            r = random.randint(0, number)
            if r < number / 10