PyTorch学习笔记(18)–划分训练集和测试集的脚本文件
本博文是PyTorch的学习笔记,第18次内容记录,主要记录了如何自动的划分训练集和测试集。主要包括了2种方式,第1种方式针对的是数据集是按照类别存放在多个文件夹中,适用于分类问题,将同一类的图片划分为训练集和测试集,第2种方式针对数据不按照分类存放,而是直接放在同一个文件夹下,将数据分成训练集和测试集。
1.按分类存放
在进行训练集与测试集划分时,需要划分的文件夹是:flower_data/flower_photos,下面有5个分类的文件夹,分别为:daisy、dandelion、roses、sunflowers、tulips,进行分类的脚本为split_data.py,脚本文件split_data.py与数据文件夹flower_photos是并列的关系,都放在flower_data文件夹下,如下图所示:
脚本文件split_data.py的代码如下:
# coding :UTF-8
# 文件功能: 代码实现自动将数据集划分为训练集和验证集的功能
# 开发人员: XXX
# 开发时间: 2021/12/3 6:07 下午
# 文件名称: split_data.py
# 开发工具: PyCharm
import os
from shutil import copy, rmtree
import random
def mk_file(file_path: str):
if os.path.exists(file_path):
# 如果文件夹存在,则先删除原文件夹再重新创建
rmtree(file_path)
os.makedirs(file_path)
def main():
# 保证随机可复现
random.seed(0)
# 将数据集中10%的数据划分到验证集中
split_rate = 0.1
# 指向你解压后的flower_photos文件夹
cwd = os.getcwd() # 用于返回当前工作目录
data_root = os.path.join(cwd, "flower_data")
origin_flower_path = os.path.join(data_root, "flower_photos")
assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)
flower_class = [cla for cla in os.listdir(origin_flower_path)
if os.path.isdir(os.path.join(origin_flower_path, cla))]
# 建立保存训练集的文件夹
train_root = os.path.join(data_root, "train")
mk_file(train_root)
for cla in flower_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(train_root, cla))
# 建立保存验证集的文件夹
val_root = os.path.join(data_root, "val")
mk_file(val_root)
for cla in flower_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(val_root, cla))
for cla in flower_class:
cla_path = os.path.join(origin_flower_path, cla)
images = os.listdir(cla_path)
num = len(images)
# 随机采样验证集的索引
eval_index = random.sample(images, k=int(num * split_rate))
for index, image in enumerate(images):
if image in eval_index:
# 将分配至验证集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(val_root, cla)
copy(image_path, new_path)
else:
# 将分配至训练集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(train_root, cla)
copy(image_path, new_path)
print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="") # processing bar
print()
print("processing done!")
if __name__ == '__main__':
main()
2.所有的按一个文件夹存放
当所有的图像放在一个文件夹下存放时,将这个文件夹下的图像分成训练集和测试集,windows版本代码如下:
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 20 16:28:13 2021
@author: NN
"""
import os
import random
import shutil
# 原始数据集路径
# origion_path = r'D:\蓝藻门'
origion_path = r'E:\BaiduNetdiskDownload\data_20211112_train_new'
names = os.listdir(origion_path)
# 保存路径
# save_train_dir = r'D:\藻类识别神经网络\分类网络\train'
# save_test_dir = r'D:\藻类识别神经网络\分类网络\test'
# 数据集类别及数量
for i in names:
file_list = origion_path + '\\' + i
image_list = os.listdir(file_list) # 获取图片的原始路径
image_number = len(image_list)
train_number = int(image_number * 0.75)
train_sample = random.sample(image_list, train_number) # 从image_list中随机获取0.8比例的图像.
test_sample = list(set(image_list) - set(train_sample))
# 创建保存路径
save_train_dir = r'D:\藻类数据\data_20211112\train' + '\\' + i
save_test_dir = r'D:\藻类数据\data_20211112\test' + '\\' + i
if not os.path.isdir(save_train_dir):
os.makedirs(save_train_dir)
if not os.path.isdir(save_test_dir):
os.makedirs(save_test_dir)
# 复制图像到目标文件夹
for j in train_sample:
shutil.copy(file_list + '\\' + j, save_train_dir)
for k in test_sample:
shutil.copy(file_list + '\\' + k, save_test_dir)