1.前言
为了便于后续在同一数据集上进行测试,本系列统一采用开源的花分类数据集,下载地址为:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
(有可能因为境外服务器的原因下载失败,此处给出flower_photos.tgz百度网盘分享链接:https://pan.baidu.com/s/1OdyKGBQ7wZI_52xAvEI2FQ
提取码:8idw
--来自百度网盘超级会员V4的分享)
2.数据预处理
解压下载到的压缩包后,进行数据预处理:按照一定比例(训练集:验证集一般为8:2或9:1)拆分数据集。
具体代码如下(代码中 split_rate 取 0.1,即按9:1划分;脚本默认在当前工作目录下的 flower_data/flower_photos 中查找解压后的图片),视不同情况需要自己调整数据集的路径,如果相对路径不行可以尝试绝对路径。
import os
from shutil import copy, rmtree
import random
def mk_file(file_path: str) -> None:
    """Create an empty directory at ``file_path``, replacing whatever is there.

    An existing directory is removed together with its contents. A regular
    file or symbolic link occupying the path is deleted as well, so the call
    ends with a fresh, empty directory in every case. Missing parent
    directories are created.

    Args:
        file_path: Path of the directory to (re)create.

    Raises:
        OSError: If the existing entry cannot be removed or the directory
            cannot be created.
    """
    if os.path.isdir(file_path) and not os.path.islink(file_path):
        # Drop the previous contents so the directory starts out empty.
        rmtree(file_path)
    elif os.path.lexists(file_path):
        # A file or symlink (possibly dangling) is in the way; rmtree refuses
        # to operate on those, so unlink the entry instead.
        os.remove(file_path)
    os.makedirs(file_path)
def main(split_rate: float = 0.1, data_root=None) -> None:
    """Split the flower photo dataset into ``train`` and ``val`` folders.

    Every sub-directory of ``<data_root>/flower_photos`` is treated as one
    class. For each class, ``int(num_images * split_rate)`` images are sampled
    at random and copied to ``<data_root>/val/<class>``; the remaining images
    are copied to ``<data_root>/train/<class>``. The ``train`` and ``val``
    folders are recreated from scratch on every run; the original photos are
    left untouched.

    Args:
        split_rate: Fraction of each class that goes to the validation set.
            Must lie in ``[0, 1]``.
        data_root: Directory that contains ``flower_photos``. Defaults to
            ``flower_data`` inside the current working directory.

    Raises:
        ValueError: If ``split_rate`` is outside ``[0, 1]``.
        AssertionError: If ``<data_root>/flower_photos`` does not exist.
    """
    if not 0.0 <= split_rate <= 1.0:
        raise ValueError("split_rate must be in [0, 1], got {}".format(split_rate))

    # Fixed seed so that repeated runs select the same validation images.
    random.seed(0)

    if data_root is None:
        data_root = os.path.join(os.getcwd(), "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    # os.listdir returns entries in arbitrary order; sort so that the seeded
    # sampling below is reproducible across machines and file systems.
    flower_class = sorted(
        cla for cla in os.listdir(origin_flower_path)
        if os.path.isdir(os.path.join(origin_flower_path, cla))
    )

    # Recreate the training and validation folders, one sub-folder per class.
    train_root = os.path.join(data_root, "train")
    val_root = os.path.join(data_root, "val")
    for split_root in (train_root, val_root):
        mk_file(split_root)
        for cla in flower_class:
            mk_file(os.path.join(split_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        # Only regular files can be copied; skip any nested directories.
        images = sorted(
            name for name in os.listdir(cla_path)
            if os.path.isfile(os.path.join(cla_path, name))
        )
        num = len(images)
        # Random sample of validation images; a set keeps the membership
        # test in the loop below O(1).
        eval_index = set(random.sample(images, k=int(num * split_rate)))
        for index, image in enumerate(images):
            target_root = val_root if image in eval_index else train_root
            copy(os.path.join(cla_path, image), os.path.join(target_root, cla))
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
        print()

    print("processing done!")
if __name__ == '__main__':
    # Run the split only when this file is executed as a script.
    main()
运行后数据集目录如图所示(flower_data 目录下会新生成 train 和 val 两个文件夹,其中每个类别各对应一个子文件夹):