1. 随机划分训练集、验证集和测试集
import os
import glob
import random
import shutil
# Source directory that is scanned for per-class sub-directories of images.
dataset_dir = './XXX_classification/'

# Destination directory of each split; the class name is appended to these.
train_dir, valid_dir, test_dir = (
    './datasets/train/',
    './datasets/val/',
    './datasets/test/',
)

# Fraction of every class assigned to each split. The split code reads only
# train_per and valid_per; the test split receives whatever remains.
train_per, valid_per, test_per = 0.8, 0.1, 0.1
def makedir(new_dir):
    """Create ``new_dir`` and any missing parent directories.

    Calling this for a directory that already exists is a no-op.

    Args:
        new_dir: Path of the directory to create.

    Raises:
        FileExistsError: If ``new_dir`` exists but is not a directory.
    """
    # exist_ok=True replaces the separate os.path.exists() check, which left a
    # window where another process could create the directory in between.
    os.makedirs(new_dir, exist_ok=True)
def _split_class(class_name, class_path):
    """Copy one class's ``.jpg`` images into the train/val/test directories.

    The images directly inside ``class_path`` are shuffled with a fixed seed
    and cut at ``train_per`` and ``train_per + valid_per``; the remainder goes
    to the test split. Each image is copied to
    ``<split_dir><class_name>/<file name>`` and a one-line summary is printed.

    Args:
        class_name: Name of the class; used as the output sub-directory name.
        class_path: Directory that holds the class's source images.
    """
    # glob's result order depends on the file system, so sort first: otherwise
    # the fixed seed below would not produce the same split on every machine.
    imgs_list = sorted(glob.glob(class_path + '/*.jpg'))

    # Re-seed for every class so each class's split is independent of how
    # many other classes were processed before it.
    random.seed(666)
    random.shuffle(imgs_list)

    imgs_num = len(imgs_list)
    train_point = int(imgs_num * train_per)
    valid_point = int(imgs_num * (train_per + valid_per))

    splits = (
        (train_dir, imgs_list[:train_point]),
        (valid_dir, imgs_list[train_point:valid_point]),
        (test_dir, imgs_list[valid_point:]),
    )
    for split_dir, split_imgs in splits:
        if not split_imgs:
            # Do not create output directories for empty splits.
            continue
        out_dir = split_dir + class_name + '/'
        makedir(out_dir)  # once per split instead of once per image
        for img_path in split_imgs:
            shutil.copy(img_path, out_dir + os.path.split(img_path)[-1])

    print('Class:{}, train:{}, valid:{}, test:{}'.format(class_name, train_point, valid_point-train_point, imgs_num-valid_point))


if __name__ == '__main__':
    # Only the immediate sub-directories of dataset_dir are classes. Taking
    # just the first os.walk() entry avoids descending into nested folders,
    # which would otherwise be treated as extra classes. The default keeps a
    # missing dataset_dir a silent no-op, as it is when looping over os.walk().
    root, class_dirs, _ = next(os.walk(dataset_dir), (dataset_dir, [], []))
    for sDir in sorted(class_dirs):
        _split_class(sDir, os.path.join(root, sDir))
2. 为划分好的训练集、验证集和测试集生成 meta 格式的标注文件
import os
from glob import glob
from pathlib import Path
def generate_mmcls_ann(data_dir, img_type='.jpg', classes=None):
    """Write one annotation list per dataset split under ``<data_dir>/meta/``.

    For each of ``train``, ``val`` and ``test`` that exists as a directory
    inside ``data_dir``, every class sub-directory is scanned for files whose
    name ends with ``img_type``. One line ``<class_dir>/<file_name> <class_id>``
    is produced per image and the lines are written to
    ``<data_dir>/meta/<split>.txt``, separated by newlines with no trailing
    newline. A split without images yields an empty file.

    Args:
        data_dir: Root directory that contains the split directories.
        img_type: File-name suffix of the images to list.
        classes: Sequence of class directory names; the position of a name is
            its numeric class id. Defaults to
            ``['0000', '0001', '0002', '0003']``.

    Raises:
        KeyError: If a split contains a sub-directory that is not in
            ``classes``.
    """
    if classes is None:
        classes = ['0000', '0001', '0002', '0003']
    class2id = {name: idx for idx, name in enumerate(classes)}

    # Normalise once so the string concatenations below always see exactly
    # one trailing separator.
    data_dir = str(Path(data_dir)) + '/'
    ann_dir = data_dir + 'meta/'
    os.makedirs(ann_dir, exist_ok=True)

    for sd in ('train', 'val', 'test'):
        target_dir = data_dir + sd + '/'
        if not os.path.isdir(target_dir):
            continue

        annotations = []
        # Sorted so the annotation file does not depend on the file system's
        # directory listing order.
        for d in sorted(os.listdir(target_dir)):
            if not os.path.isdir(target_dir + d):
                # Stray files next to the class directories are not classes.
                continue
            if d not in class2id:
                raise KeyError(
                    'unknown class directory {!r} in {!r}'.format(d, target_dir))
            class_id = str(class2id[d])
            images = sorted(glob(target_dir + d + '/*' + img_type))
            annotations.extend(
                d + '/' + os.path.basename(img) + ' ' + class_id
                for img in images)

        # Joining avoids indexing annotations[-1], which raises IndexError
        # when the split contains no images. The explicit encoding keeps the
        # output independent of the platform's locale.
        with open(ann_dir + sd + '.txt', 'w', encoding='utf-8') as f:
            f.write('\n'.join(annotations))
if __name__ == '__main__':
    # Generate the annotation lists for the dataset rooted at ./datasets/.
    data_dir = './datasets/'
    generate_mmcls_ann(data_dir=data_dir)