由于实验过程中需要随机划分训练集、验证集和测试集,编写了以下代码,供日后参考使用。
import glob,os
from os.path import join,basename
import shutil
import random
def mycopyfile(srcfile, dstpath):
    """Copy a single file into a destination directory.

    Args:
        srcfile: Path of the file to copy.
        dstpath: Directory that receives the copy. It is created (including
            missing parents) when it does not exist. A trailing path
            separator is optional.

    When ``srcfile`` is not an existing regular file, a "not exist" message
    is printed and nothing is copied. Otherwise the source and target paths
    are printed after the copy.
    """
    if not os.path.isfile(srcfile):
        print("%s not exist!" % (srcfile))
        return

    fname = os.path.basename(srcfile)
    # exist_ok avoids a crash if the directory appears between a check and
    # the creation (e.g. two copies running at once).
    os.makedirs(dstpath, exist_ok=True)
    # Join rather than concatenate, so that "dst" and "dst/" both put the
    # file inside the directory instead of creating a sibling "dst<fname>".
    target = os.path.join(dstpath, fname)
    shutil.copy(srcfile, target)
    print("copy %s -> %s" % (srcfile, target))
def transfrom(base_dir="/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/image1",
              src_subdir="train3",
              suffix="_img.tif",
              subsets=("train", "test", "val")):
    """Copy the files named in each subset list into a per-subset directory.

    For every ``subset`` in ``subsets`` the file ``<base_dir>/<subset>.txt``
    is read line by line; each non-empty line is a file stem, and
    ``<base_dir>/<src_subdir>/<stem><suffix>`` is copied into
    ``<base_dir>/<subset>/`` via ``mycopyfile``.

    Args:
        base_dir: Directory holding the list files, the source
            sub-directory and the per-subset output directories.
        src_subdir: Name of the sub-directory containing all source files.
        suffix: Text appended to each stem to form the source file name.
        subsets: Subset names, processed in the given order.

    NOTE(review): split() in this file writes its list files under a
    ``mask1`` directory, while the default ``base_dir`` here points at
    ``image1`` - confirm the lists are present there before running.
    """
    src_dir = join(base_dir, src_subdir)
    for subset in subsets:
        # The trailing "" keeps a separator at the end of the destination,
        # so the path is also valid for callers that concatenate onto it.
        dst_dir = join(base_dir, subset, "")
        with open(join(base_dir, subset + ".txt"), 'r') as f:
            for line in f:
                stem = line.rstrip("\r\n")
                if not stem:
                    # A blank line would otherwise request "<src_dir>/<suffix>".
                    continue
                mycopyfile(join(src_dir, stem + suffix), dst_dir)
def split(val_percent=0.15,
          test_percent=0.15,
          allpath="/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/mask1/train3",
          out_dir="/mnt/sdb3/liutiancheng/IKDNet-pytorch-main/dataset/N3C-California/mask1",
          suffix_len=7):
    """Randomly partition the files in ``allpath`` into train/val/test lists.

    The entries of ``allpath`` are listed (sorted by name) and split into
    three disjoint groups. One line per entry is written to ``train.txt``,
    ``val.txt`` or ``test.txt`` inside ``out_dir``. Each line is the entry
    name with its last ``suffix_len`` characters removed; the default of 7
    is presumably the length of the mask file suffix - TODO confirm against
    the actual file naming.

    Args:
        val_percent: Fraction of all entries assigned to the validation list.
        test_percent: Fraction of all entries assigned to the test list.
            Whatever is left after validation and test goes to training.
        allpath: Directory whose entries are split.
        out_dir: Existing directory that receives the three list files.
        suffix_len: Number of trailing characters dropped from each entry
            name; 0 keeps the name unchanged.

    Raises:
        ValueError: If a fraction is negative or the two sum to more than 1.
    """
    if val_percent < 0 or test_percent < 0 or val_percent + test_percent > 1:
        raise ValueError(
            "val_percent and test_percent must be non-negative and sum to at most 1")

    # Sorted so the index -> file mapping does not depend on listdir order.
    filenames = sorted(os.listdir(allpath))
    num = len(filenames)

    trainval_count = int(num * (1 - test_percent))  # train + val together
    # min() guards against float rounding pushing val above train + val.
    val_count = min(int(num * val_percent), trainval_count)

    trainval = random.sample(range(num), trainval_count)
    # Sets give O(1) membership tests in the loop below; the validation
    # sample must still be drawn from the list, not from a set.
    val_ids = set(random.sample(trainval, val_count))
    trainval_ids = set(trainval)

    with open(os.path.join(out_dir, 'test.txt'), 'w') as ftest, \
            open(os.path.join(out_dir, 'train.txt'), 'w') as ftrain, \
            open(os.path.join(out_dir, 'val.txt'), 'w') as fval:
        for i, filename in enumerate(filenames):
            # filename[:-0] would be empty, so 0 is handled explicitly.
            stem = filename[:-suffix_len] if suffix_len else filename
            name = stem + '\n'
            print(name)
            if i in val_ids:
                fval.write(name)
            elif i in trainval_ids:
                ftrain.write(name)
            else:
                ftest.write(name)
def main():
    """Write the train/val/test lists, then copy the listed files."""
    # NOTE(review): split() writes its lists under .../mask1/ while
    # transfrom() reads them from .../image1/ - confirm the list files are
    # available in image1 before this second step runs.
    split()
    transfrom()


if __name__ == "__main__":
    main()