文件结构
- 数据集 data_dir 文件结构:
data_dir: 数据集图片所在目录(可包含其他文件,本脚本只处理扩展名为 .jpg 的图片)
├── 1.jpg: 图片1
├── 2.jpg: 图片2
└── n.jpg: 图片n
- 生成文件结构:划分后的数据集保存在 save_dir 下(图片为复制,原数据集保持不变)
save_dir: 划分后图片保存目录
├── train: 训练集
│   ├── 1.jpg: 图片1
│   └── n.jpg: 图片n
├── val: 验证集
└── test: 测试集(可根据需要选择是否划分测试集;仅当 proportion 含 2 个元素时生成)
代码实现
- 检查数据集data_dir是否存在,否则报错;
# Pass the message string itself, not print(...): print() returns None, so the
# AssertionError would carry no message and the text would go to stdout instead.
assert os.path.exists(data_dir), \
    '{} does not exist, please input again!'.format(data_dir)
- 检查保存路径是否存在,否则创建;
def mkdir(check_dir):
    """Create ``check_dir`` (and any missing parents) if it does not exist yet.

    Raises:
        FileExistsError: ``check_dir`` exists but is not a directory.
    """
    # exist_ok=True avoids the race between an os.path.exists() check and the
    # makedirs() call, and refuses to treat an existing plain file as a folder.
    os.makedirs(check_dir, exist_ok=True)
- 随机打乱数据集图片;
# os.listdir() order depends on the filesystem, so sort first to make the
# shuffle reproducible once random.seed() is set. The extension check ignores
# case (.jpg / .JPG), and only regular files are kept so that a sub-directory
# named like "x.jpg" is not handed to shutil.copy later.
img_list = [os.path.join(data_dir, img) for img in sorted(os.listdir(data_dir))
            if img.lower().endswith('.jpg')]
img_list = [path for path in img_list if os.path.isfile(path)]
random.shuffle(img_list)
- 根据划分比例计算训练集、验证集和测试集(若需)大小;
# Only [train] or [train, val] ratios are supported; any other length would
# leave train_size undefined and nothing would be split.
if len(proportion) not in (1, 2):
    raise ValueError(
        'proportion must contain 1 or 2 ratios, got {}'.format(len(proportion)))
# The small tolerance keeps e.g. [0.6, 0.4] from being rejected by float rounding.
if any(ratio < 0 for ratio in proportion) or sum(proportion) > 1 + 1e-9:
    raise ValueError(
        'proportion ratios must be >= 0 and sum to at most 1, got {}'.format(proportion))
# int() truncates; the epsilon stops float error from dropping an image,
# e.g. 100 * 0.29 == 28.999999999999996.
train_size = int(len(img_list) * proportion[0] + 1e-9)
if len(proportion) == 2:
    val_size = int(len(img_list) * proportion[1] + 1e-9)
- 将数据集划分并保存至对应目录
# Map each output sub-folder to its consecutive slice of the shuffled list,
# then copy every slice into save_dir/<split name> (train, then val, then test).
split_plan = {}
if len(proportion) == 1:
    split_plan = {"train": img_list[:train_size], "val": img_list[train_size:]}
elif len(proportion) == 2:
    val_end = train_size + val_size
    split_plan = {
        "train": img_list[:train_size],
        "val": img_list[train_size:val_end],
        "test": img_list[val_end:],
    }
for split_name, split_files in split_plan.items():
    target_dir = os.path.join(save_dir, split_name)
    for src_path in split_files:
        shutil.copy(src_path, target_dir)
完整代码
import os
import random
import shutil
def mkdir(check_dir):
    """Create ``check_dir`` (and any missing parents) if it does not exist yet.

    Raises:
        FileExistsError: ``check_dir`` exists but is not a directory.
    """
    # exist_ok=True avoids the race between an os.path.exists() check and the
    # makedirs() call, and refuses to treat an existing plain file as a folder
    # (copying images "into" such a file would overwrite it repeatedly).
    os.makedirs(check_dir, exist_ok=True)
def split_dataset(data_dir, proportion, save_dir, seed=None, extensions=('.jpg',)):
    """Randomly split the images in ``data_dir`` into train/val(/test) folders.

    Images are copied (the originals stay in place) into ``save_dir/train``,
    ``save_dir/val`` and, when ``proportion`` has two ratios, ``save_dir/test``.
    The split folders are created even if they end up empty.

    Args:
        data_dir: Directory holding the images. Only regular files directly
            inside it whose name ends with one of ``extensions`` (compared
            case-insensitively) are used; sub-directories are not searched.
        proportion: ``[train]`` -> the remaining images go to val;
            ``[train, val]`` -> the remaining images go to test. Each ratio
            must be >= 0 and the ratios must sum to at most 1. Sizes are
            rounded down.
        save_dir: Output directory; it and the split sub-folders are created
            when missing.
        seed: Optional seed for a private ``random.Random``. With ``None``
            (default) the global ``random`` module is used, as before.
        extensions: File-name suffix (or tuple of suffixes) treated as images.

    Raises:
        AssertionError: ``data_dir`` does not exist (checked with ``assert``,
            so the check is skipped under ``python -O``).
        ValueError: ``proportion`` does not have 1 or 2 ratios, or the ratios
            are negative / sum to more than 1.
    """
    # Pass the message string itself; print(...) would return None and leave
    # the AssertionError without a message.
    assert os.path.exists(data_dir), \
        '{} does not exist, please input again!'.format(data_dir)
    if len(proportion) not in (1, 2):
        raise ValueError(
            'proportion must contain 1 or 2 ratios, got {}'.format(len(proportion)))
    # Small tolerance so e.g. [0.6, 0.4] is not rejected because of float rounding.
    if any(ratio < 0 for ratio in proportion) or sum(proportion) > 1 + 1e-9:
        raise ValueError(
            'proportion ratios must be >= 0 and sum to at most 1, got {}'.format(proportion))

    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # os.listdir() order depends on the filesystem; sorting first makes the
    # shuffle reproducible for a given seed / global random state.
    candidates = [os.path.join(data_dir, name)
                  for name in sorted(os.listdir(data_dir))
                  if name.lower().endswith(suffixes)]
    # Skip sub-directories that happen to carry an image suffix (e.g. "x.jpg/").
    img_list = [path for path in candidates if os.path.isfile(path)]
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(img_list)

    total = len(img_list)
    # int() truncates; the epsilon stops float error from dropping an image,
    # e.g. 100 * 0.29 == 28.999999999999996.
    train_end = int(total * proportion[0] + 1e-9)
    if len(proportion) == 1:
        splits = {"train": img_list[:train_end], "val": img_list[train_end:]}
    else:
        val_end = train_end + int(total * proportion[1] + 1e-9)
        splits = {
            "train": img_list[:train_end],
            "val": img_list[train_end:val_end],
            "test": img_list[val_end:],
        }

    for split_name, files in splits.items():
        split_dir = os.path.join(save_dir, split_name)
        # exist_ok=True: re-running into an existing save_dir is allowed; also
        # creates save_dir itself when missing.
        os.makedirs(split_dir, exist_ok=True)
        for src in files:
            shutil.copy(src, split_dir)
if __name__ == "__main__":
    # Example run: 60% of the images to train, 20% to val, the rest to test.
    data_dir, proportion, save_dir = './test_dir', [0.6, 0.2], './test_result_dir'
    split_dataset(data_dir=data_dir, proportion=proportion, save_dir=save_dir)
    print('finished!')
使用方法
- 修改参数
data_dir : str,指定为待划分数据集路径,需存在否则报错
proportion :list,只含 1 个元素时表示训练集所占比例,剩余图片划分至验证集;含 2 个元素时分别表示训练集、验证集所占比例,剩余图片划分至测试集
save_dir :str,指定为划分后数据集保存路径,不存在自动创建