使用背景
当训练数据的时候,找到一个数据集,专业的数据集里面并不会帮你划分好训练集测试集等。目的是给用户更好地自定义使用。而我们训练模型之前必须分出训练集和测试集。据了解pytorch 训练库没有提供关于数据集划分的处理方法。以下是一个简单的脚本,只需要输入训练集的路径,以及调节训练集测试集的比例,即可分出两个数据集,简单实用。
具体办法
很简单,几个步骤
- 在数据集所在的路径同级目录创建一个train,test文件夹。
- 如果是分类模型,继续在 train,test文件夹下创建分类的文件夹。
- 将数据集打乱后按照比例复制粘贴到train,test文件夹。
代码
直接上代码:
# -*- coding: utf-8 -*-
import os
import random
import shutil
#用于复制粘贴
# 将数据集分为两份,重新的复制粘贴到一个新的文件夹之后
def extract_set(trainSet_rate:float,org_path:str):
class_path = os.listdir(path) # 分类的文件夹,用于遍历划分和创建新的文件夹
# ==============================================================
# 创建文件夹,必须是在于所有数据集的文件上一级 创建 train test ,
# 已经创建了 先注释掉,除了新建一个 test ,train
# 文件夹之外 还需要继续创建train,test 下面的分类的文件夹
upper_path = os.path.abspath(os.path.dirname(path)) # 上一级的目录
if not os.path.exists(os.path.join(upper_path, r"train")):
# print("no exist")
os.mkdir((os.path.join(upper_path, "train")))
if not os.path.exists(os.path.join(upper_path, r"test")):
# print("no exist")
os.mkdir(os.path.join(upper_path, "test"))
# 创建原有的数据集中每个分类的文件夹
for i in class_path:
if not os.path.exists(os.path.join(upper_path, "train/{}".format(i))):
# print("no exist")
os.mkdir((os.path.join(upper_path, "train/{}".format(i))))
if not os.path.exists(os.path.join(upper_path, "test/{}".format(i))):
#print("no exist")
os.mkdir(os.path.join(upper_path, "test/{}".format(i)))
# ==============================================================
for i in class_path:
content_name_list = os.listdir(os.path.join(path, i)) # 未分配数据集之前训练数据文件名,非路径)
total = len(content_name_list) # 总数量
train_len = int(total * trainSet_rate) # 分配训练集和测试集
random.shuffle(content_name_list) # 将数据集随机错乱
trainSet_path = content_name_list[:train_len:] # 训练集的所有数据
testSet_path = content_name_list[train_len::] # 测试集的所有数
#复制
for j in trainSet_path:
shutil.copy(os.path.join(path, "{}/{}".format(i, j)), os.path.join(upper_path, "train/{}/{}".format(i, j)))
for j in testSet_path:
shutil.copy(os.path.join(path, "{}/{}".format(i, j)), os.path.join(upper_path, "test/{}/{}".format(i, j)))
print("数据集划分已完成!")
if __name__ == '__main__':
trainSet_rate = 0.7 #训练集:测试集 = 7:3
path = "..../ten_annimals/raw-img" #初始数据集路径
extract_set(trainSet_rate=trainSet_rate,org_path=path)