基于python的训练集测试集划分(个人记录)

由于实验过程中需要随机划分部分的训练集和测试集,编写了以下代码,供日后使用作参考

import glob,os
from os.path import join,basename
import shutil
import random


def mycopyfile(srcfile, dstpath):  # 复制函数
    if not os.path.isfile(srcfile):
        print("%s not exist!" % (srcfile))
    else:
        fpath, fname = os.path.split(srcfile)  # 分离文件名和路径
        if not os.path.exists(dstpath):
            os.makedirs(dstpath)  # 创建路径
        shutil.copy(srcfile, dstpath + fname)  # 复制文件
        print("copy %s -> %s" % (srcfile, dstpath + fname))
def transfrom():
    with open("/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/train.txt",'r') as f:
        for name in f.readlines():
            namesp=name.split('\n')[0]
            mycopyfile(join("/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/train3",namesp+"_img.tif"),'/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/train/')
    with open("/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/test.txt",'r') as f:
        for name in f.readlines():
            namesp=name.split('\n')[0]
            mycopyfile(join("/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/train3",namesp+"_img.tif"),'/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/test/')
    with open("/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/val.txt",'r') as f:
        for name in f.readlines():
            namesp=name.split('\n')[0]
            mycopyfile(join("/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/train3",namesp+"_img.tif"),'/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1/val/')
def split():
    val_percent = 0.15
    test_percent = 0.15
    train_percent = 0.70

    allpath="/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/mask1/train3"
    total_xml = os.listdir(allpath)
    num = len(total_xml)  # 统计所有的标注文件
    list = range(num)
    tr = int(num * 0.85)  # 设置训练和验证集的数目
    tv = int(num * 0.15)  # 设置训练集的数目
    te = num-tr-tv
    trainval = random.sample(list, tr)
    val = random.sample(trainval, tv)

    ftest = open('/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/mask1/test.txt', 'w')
    ftrain = open('/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/mask1/train.txt', 'w')
    fval = open('/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/mask1/val.txt', 'w')
    trainnum=0
    testnum=0
    valnum=0
    for i in list:
        name = total_xml[i][:-7] + '\n'
        print(name)
        if i in trainval:
            if i in val:
                fval.write(name)
                valnum+=1
            else:
                ftrain.write(name)
                trainnum+=1
        else:
            ftest.write(name)
            testnum+=1
    ftrain.close()
    fval.close()
    ftest.close()


if __name__ == '__main__':
    split()
    transfrom()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值