这是一个花卉数据集,每个类别的图片存放在以类别名命名的文件夹中(如 daisy 等)。这里我们希望把数据划分为 train 和 val 两部分,其中 train 占 90%,val 占 10%。
import os
from shutil import rmtree, copy
import random
def mk_file(file_path):
    """Create an empty directory at ``file_path``, replacing any existing one.

    If ``file_path`` already exists it is removed with ``shutil.rmtree``
    (together with everything inside it) and then recreated. Missing parent
    directories are created as well.

    Args:
        file_path: Path of the directory to (re)create.
    """
    if os.path.exists(file_path):
        rmtree(file_path)
    # makedirs instead of mkdir so a missing parent directory does not abort
    # the call with FileNotFoundError.
    os.makedirs(file_path)
if __name__ == '__main__':
    # Seed the RNG so repeated runs over the same data give the same split.
    random.seed(0)
    # All paths below are resolved relative to the current working directory.
    cwd = os.getcwd()
    # Fraction of each class that is moved to the validation set (10%).
    splitrate = 0.1
    origin_flower_path = os.path.join(cwd, 'flower_data/flower_photos')
    # Every sub-directory of the source folder is treated as one class.
    # os.listdir() returns entries in arbitrary order, so sort them; otherwise
    # the seeded RNG would be consumed in a different order from run to run.
    flower_class = sorted(
        cla for cla in os.listdir(origin_flower_path)
        if os.path.isdir(os.path.join(origin_flower_path, cla))
    )
    # Recreate empty train/ and val/ trees with one sub-folder per class.
    train_root = os.path.join(cwd, 'train')
    val_root = os.path.join(cwd, 'val')
    for split_root in (train_root, val_root):
        mk_file(split_root)
        for cla in flower_class:
            mk_file(os.path.join(split_root, cla))
    # For each class, draw splitrate of the images with random.sample for
    # validation and send the remainder to training.
    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        # Sorted for a reproducible sample; only regular files are kept
        # because shutil.copy cannot copy a directory.
        images = sorted(
            name for name in os.listdir(cla_path)
            if os.path.isfile(os.path.join(cla_path, name))
        )
        number_images = len(images)
        # A set gives O(1) membership tests inside the copy loop.
        eval_index = set(random.sample(images, int(splitrate * number_images)))
        for index, image in enumerate(images):
            # Validation images go under val/<class>, the rest under train/<class>.
            target_root = val_root if image in eval_index else train_root
            copy(os.path.join(cla_path, image), os.path.join(target_root, cla))
            # "\r" rewrites the same console line to act as a progress bar.
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, number_images), end="")  # processing bar
        # Finish the progress line of this class.
        print()
    print('Process Done')
结果如下