1、源数据集介绍
数据来源于TensorFlow官网:花数据集。
原始数据集展示(每个文件夹内为文件名所对应花种类的图像):
在划分数据集时应注意LICENSE.txt文件的存在。
2、代码
将数据集按 7:3 的比例划分为训练集和验证集两部分。
若是读者对 os 模块不太熟悉,可见我另一篇博客:Python系列 | os模块常用命令
import os
import random
import shutil
def make_file(path):
if os.path.exists(path):
os.removedirs(path)
os.makedirs(path)
else:
os.makedirs(path)
def main():
path = r'.\dataset' # 训练集和验证集的存储路径
par_path = os.path.join(os.getcwd(), 'flower_photos') # 原始数据集路径
label_set = [label_i for label_i in os.listdir(par_path) if os.path.isdir(os.path.join(par_path, label_i))] # 类别标签
# 训练集和验证集的比例
rate = 0.7
# 创建训练集/测试集文件夹
train_path = os.path.join(path, 'train') # 训练集路径
make_file(train_path)
val_path = os.path.join(path, 'val') # 验证集路径
make_file(val_path)
for label in label_set:
# 创建训练集/验证集内的花种类文件
os.mkdir(os.path.join(train_path, label))
os.mkdir(os.path.join(val_path, label))
root = os.path.join(par_path, label)
image_set = os.listdir(root)
sample_index = random.sample(image_set, k=int(rate * len(image_set)))
# 分配数据
for i, j in enumerate(image_set):
origin_path = os.path.join(root, j)
tgt_train_path = os.path.join(train_path, label)
tgt_val_path = os.path.join(val_path, label)
if j in sample_index:
shutil.copy(origin_path, tgt_train_path)
else:
shutil.copy(origin_path, tgt_val_path)
if i == len(image_set) - 1:
print(f'< {label} > processing | {i + 1}/{len(image_set)} |')
print('Finished !')
if __name__ == '__main__':
main()
打印结果:
< daisy > processing | 633/633 |
< dandelion > processing | 898/898 |
< roses > processing | 641/641 |
< sunflowers > processing | 699/699 |
< tulips > processing | 799/799 |
Finished !
划分结果: