1. 背景说明
涉及单分类,多分类问题,共计5W多样本,43个类,以及一个标注文件(包括文件名、宽、高以及坐标)。其中每类图片数量不均等,图片尺寸也不尽相同。
数据源:German Traffic Sign Benchmarks
下载训练与测试数据目录如下:
这里只识别“STOP”、“禁止通行”、“直行”、“环岛行驶”四个类别。对应训练目录下的类别文件夹分别为00014, 00017, 00035, 00040。
2. 数据集切分
1. 将Final_training 目录按7:2:1切分为训练集、验证集、测试集三个文件夹,每个文件夹又分为“STOP”、“禁止通行”、“直行”、“环岛行驶”四个子目录。
11
import shutil
from pathlib import Path
from glob import glob
import numpy as np
def split_train_val_test_dataset(data_dir, data_sets, class_names, class_indices, train_folder):
# 1. 创建对应目录
for dt in data_sets:
for cls in class_names:
# exist_ok=True时,在目录已存在的情况下,不会触发FileExistsError异常
(data_dir/dt/cls).mkdir(parents=True, exist_ok=True)
# 2. 将原始数据集进行切分,并拷贝图片到目标文件夹
for i, cls_index in enumerate(class_indices):
img_paths = np.array(glob(f'{train_folder[int(cls_index)]}/*.ppm'))
class_name = class_names[i] # 标签
print(f'{class_name}: {len(img_paths)}')
np.random.shuffle(img_paths) # 打乱图片路径
# 对img_paths进行切分,本质上是索引切分,indices_or_sections定义切分点(0.7和0.9)
ds_split = np.split(
img_paths,
indices_or_sections=[int(0.7*len(img_paths)), int(0.9*len(img_paths))]
)
dataset = zip(data_sets, ds_split) # 拼接
for dt, img_paths in dataset:
print(f'\t{dt}, {len(img_paths)}')
for path in img_paths:
shutil.copy(path, f'{data_dir}/{dt}/{class_name}/')
2. 调用