按照训练集80%,测试集20%
因为各类别文件夹里的图片有很多重名文件,生成训练集和测试集时同名图片会互相覆盖。为了避免这种情况,对代码做了调整:复制时在文件名前加上类别名。
import codecs
import os
import random
import shutil
from PIL import Image
train_ratio = 4.0 / 5  # fraction of the images assigned to the training set (80%)
all_file_dir = './xh'  # root folder; it holds one sub-folder of images per class


def _is_readable_image(path):
    """Return True if PIL can open *path* as an image, otherwise False.

    Unreadable files are reported on stdout instead of being dropped silently,
    so the user can see which inputs need cleaning.
    """
    try:
        # The context manager closes the file handle that Image.open creates.
        with Image.open(path):
            return True
    except Exception as exc:  # PIL raises several unrelated exception types here
        print("skip unreadable file {0}: {1}".format(path, exc))
        return False


def split_dataset(all_file_dir, train_ratio, is_valid_image=None):
    """Randomly split per-class image folders into a training and an eval set.

    Every sub-directory of *all_file_dir* whose name neither ends with ``Set``
    nor starts with ``.`` is treated as one class. Classes are numbered in
    sorted order. Each usable file is copied to ``trainImageSet`` or
    ``evalImageSet`` under *all_file_dir*; the class name is prepended to the
    file name so that equally named files of different classes do not
    overwrite each other. ``train.txt``, ``eval.txt`` and ``label_list.txt``
    are (re)written in *all_file_dir*.

    Args:
        all_file_dir: Root directory containing the class sub-directories.
        train_ratio: A file goes to the training set when a draw from
            ``random.uniform(0, 1)`` is ``<= train_ratio``.
        is_valid_image: Optional callable taking a file path and returning
            whether the file should be used. Defaults to a check that the
            file can be opened with PIL.

    Returns:
        A ``(train_count, eval_count)`` tuple with the number of files copied
        to each set.
    """
    if is_valid_image is None:
        is_valid_image = _is_readable_image

    # Class folders only: skip the generated "*Set" output folders and hidden ones.
    class_list = sorted(
        c for c in os.listdir(all_file_dir)
        if os.path.isdir(os.path.join(all_file_dir, c))
        and not c.endswith('Set')
        and not c.startswith('.')
    )
    print(class_list)

    train_image_dir = os.path.join(all_file_dir, "trainImageSet")
    eval_image_dir = os.path.join(all_file_dir, "evalImageSet")
    os.makedirs(train_image_dir, exist_ok=True)
    os.makedirs(eval_image_dir, exist_ok=True)

    train_count = 0
    eval_count = 0
    # 'w' opens for writing and truncates any previous listing. Builtin open()
    # is what codecs.open() resolves to when no encoding is given.
    with open(os.path.join(all_file_dir, "train.txt"), 'w') as train_file, \
            open(os.path.join(all_file_dir, "eval.txt"), 'w') as eval_file, \
            open(os.path.join(all_file_dir, "label_list.txt"), 'w') as label_list:
        for label_id, class_dir in enumerate(class_list):
            label_list.write("{0}\t{1}\n".format(label_id, class_dir))
            image_path_pre = os.path.join(all_file_dir, class_dir)
            for file_name in os.listdir(image_path_pre):
                src_path = os.path.join(image_path_pre, file_name)
                if not is_valid_image(src_path):
                    continue

                # Prefix with the class name to avoid collisions between classes.
                # The same name is used for the copy and for the listing entry,
                # for both sets, so every listed path points at a real file.
                unique_name = class_dir + file_name
                if random.uniform(0, 1) <= train_ratio:
                    target_dir, subset, listing = train_image_dir, "trainImageSet", train_file
                else:
                    target_dir, subset, listing = eval_image_dir, "evalImageSet", eval_file

                try:
                    shutil.copyfile(src_path, os.path.join(target_dir, unique_name))
                except OSError as exc:
                    # Keep the run best-effort, but make the failure visible.
                    print("skip {0}: copy failed: {1}".format(src_path, exc))
                    continue

                # Only list a file after it has actually been copied.
                listing.write("{0} {1}\n".format(os.path.join(subset, unique_name), label_id))
                if listing is train_file:
                    train_count += 1
                else:
                    eval_count += 1

    return train_count, eval_count


if __name__ == "__main__":
    split_dataset(all_file_dir, train_ratio)
得到的文件:
代码参考了:基于PaddleClas的热轧钢带表面缺陷分类
https://aistudio.baidu.com/aistudio/projectdetail/685319