【目标检测】将目标检测数据集划分为训练集、验证集与测试集 python代码

文件结构

  1. 数据集data_dir文件结构:
├── data_dir: 数据集图片所在目录(可包含其他合适文件,本脚本只对图片进行统计)
    ├── 1.jpg: 图片1
    ├── 2.jpg: 图片2
    └── n.jpg: 图片n
  1. 生成文件结构:划分后数据集保存save_dir下
├── save_dir: 划分后图片保存目录
    ├── train: 训练集
    	    ├── 1.jpg: 图片1
		    └── n.jpg: 图片n
    ├── val:   验证集
    └── test:  测试集(可根据需要选择是否划分测试集)

代码实现

  1. 检查数据集data_dir是否存在,否则报错;
assert os.path.exists(data_dir), \
    print('{} does not exist, please input again!'.format(data_dir))
  1. 检查保存路径是否存在,否则创建;
def mkdir(check_dir):
    if not os.path.exists(check_dir):
        os.makedirs(check_dir)
  1. 随机打乱数据集图片;
# 获取数据集中所有的图片路径
img_list = os.listdir(data_dir)
img_list = [os.path.join(data_dir, img) for img in img_list if img.endswith('.jpg')]
# 随机打乱数据集顺序
random.shuffle(img_list)
  1. 根据划分比例计算训练集、验证集和测试集(若需)大小;
    # 计算划分后的数据集大小
    if len(proportion) == 1:
        train_size = int(len(img_list) * proportion[0])
    elif len(proportion) == 2:
        train_size = int(len(img_list) * proportion[0])
        val_size = int(len(img_list) * proportion[1])
  1. 将数据集划分并保存至对应目录
 # 将数据集划分为训练集、验证集和测试集(若有)
    if len(proportion) == 1:
        train_list = img_list[:train_size]
        val_list = img_list[train_size:]
        for img in train_list:
            shutil.copy(img, os.path.join(save_dir, "train"))
        for img in val_list:
            shutil.copy(img, os.path.join(save_dir, "val"))
    elif len(proportion) == 2:
        train_list = img_list[:train_size]
        val_list = img_list[train_size:train_size+val_size]
        test_list = img_list[train_size+val_size:]
        for img in train_list:
            shutil.copy(img, os.path.join(save_dir, "train"))
        for img in val_list:
            shutil.copy(img, os.path.join(save_dir, "val"))
        for img in test_list:
            shutil.copy(img, os.path.join(save_dir, "test"))

完整代码

import os
import random
import shutil

def mkdir(check_dir):
    if not os.path.exists(check_dir):
        os.makedirs(check_dir)

def split_dataset(data_dir, proportion, save_dir):
    assert os.path.exists(data_dir), \
    print('{} does not exist, please input again!'.format(data_dir))

    # 检查保存路径是否存在,若不存在则创建
    mkdir(save_dir)
    mkdir(os.path.join(save_dir, "train"))
    mkdir(os.path.join(save_dir, "val"))
    if len(proportion) == 2:
        mkdir(os.path.join(save_dir, "test"))

    # 获取数据集中所有的图片路径
    img_list = os.listdir(data_dir)
    img_list = [os.path.join(data_dir, img) for img in img_list if img.endswith('.jpg')]

    # 随机打乱数据集顺序
    random.shuffle(img_list)

    # 计算划分后的数据集大小
    if len(proportion) == 1:
        train_size = int(len(img_list) * proportion[0])
    elif len(proportion) == 2:
        train_size = int(len(img_list) * proportion[0])
        val_size = int(len(img_list) * proportion[1])

    # 将数据集划分为训练集、验证集和测试集(若有)
    if len(proportion) == 1:
        train_list = img_list[:train_size]
        val_list = img_list[train_size:]
        for img in train_list:
            shutil.copy(img, os.path.join(save_dir, "train"))
        for img in val_list:
            shutil.copy(img, os.path.join(save_dir, "val"))
    elif len(proportion) == 2:
        train_list = img_list[:train_size]
        val_list = img_list[train_size:train_size+val_size]
        test_list = img_list[train_size+val_size:]
        for img in train_list:
            shutil.copy(img, os.path.join(save_dir, "train"))
        for img in val_list:
            shutil.copy(img, os.path.join(save_dir, "val"))
        for img in test_list:
            shutil.copy(img, os.path.join(save_dir, "test"))
if __name__ == "__main__":
    data_dir = './test_dir'			# 数据集目录
    proportion = [0.6, 0.2]			# 划分比例
    save_dir = './test_result_dir'	# 保存路径
    split_dataset(data_dir, proportion, save_dir)
    print('finished!')

使用方法

  1. 修改参数
    data_dir : str,指定为待划分数据集路径,需存在否则报错
    proportion :list,元素为1代表训练集所占比例,剩余则划分至验证集;元素为2则代表训练集、验证集比例,剩余则划分至测试集
    save_dir :str,指定为划分后数据集保存路径,不存在自动创建
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值