如何划分训练集还有测试集

按照训练集80%,测试集20%
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

因为文件夹里的图片有很多重名文件,生成训练集测试集的时候图片会覆盖。为了避免对代码做了调正

import codecs
import os
import random
import shutil
from PIL import Image

train_ratio = 4.0 / 5#80%
all_file_dir = './xh'#文件目录xh文件夹里有上述待处理的文件夹
class_list = [c for c in os.listdir(all_file_dir) if
              os.path.isdir(os.path.join(all_file_dir, c)) and not c.endswith('Set') and not c.startswith('.')]
# 用于返回文件夹的名称
class_list.sort()  # 排序
print(class_list)
train_image_dir = os.path.join(all_file_dir, "trainImageSet")  # 这一行是路径的拼接trainImageSet放在all_file_dir路径下
if not os.path.exists(train_image_dir):  # 查看文件是否存在
    os.makedirs(train_image_dir)

eval_image_dir = os.path.join(all_file_dir, "evalImageSet")
if not os.path.exists(eval_image_dir):
    os.makedirs(eval_image_dir)

train_file = codecs.open(os.path.join(all_file_dir, "train.txt"), 'w')  # 打开文件,w为只读
eval_file = codecs.open(os.path.join(all_file_dir, "eval.txt"), 'w')

with codecs.open(os.path.join(all_file_dir, "label_list.txt"), "w") as label_list:
    label_id = 0
    for class_dir in class_list:
        label_list.write("{0}\t{1}\n".format(label_id, class_dir))  # label_id, class_dir存在label_list上
        image_path_pre = os.path.join(all_file_dir, class_dir)
        for file in os.listdir(image_path_pre):  # 返回这个目录下有什么文件
            try:
                img = Image.open(os.path.join(image_path_pre, file))
                if random.uniform(0, 1) <= train_ratio:
                    shutil.copyfile(os.path.join(image_path_pre, file),
                                    os.path.join(train_image_dir, (class_dir + file)))  # 复制一个文件到另一个文件中,(class_dir + file)为了
                                    避免每个文件夹的重名文件,在文件名前加上类别名
           
                    train_file.write("{0} {1}\n".format(os.path.join("trainImageSet", (class_dir + file)), label_id))
                else:
                    shutil.copyfile(os.path.join(image_path_pre, (class_dir + file)), os.path.join(eval_image_dir, file))
                    eval_file.write("{0} {1}\n".format(os.path.join("evalImageSet", (class_dir + file)), label_id))
            except Exception as e:
                pass
                # 存在一些文件打不开,此处需要稍作清洗
        label_id += 1

train_file.close()
eval_file.close()

得到的文件:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
代码参考了:基于PaddleClas的热轧钢带表面缺陷分类
https://aistudio.baidu.com/aistudio/projectdetail/685319

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值