1 #*_*coding: utf-8 *_*
2 #Author --LiMing--
3
4 importos5 importrandom6 importshutil7 importtime8
def copyFile(fileDir, class_name, rate=None, dest_dirs=None):
    """Randomly split the images of one class into train/test copies.

    A random ``rate`` fraction of the regular files in ``fileDir`` is copied
    into ``<dest_dirs[0]>/<class_name>/``; the remaining files are copied into
    ``<dest_dirs[1]>/<class_name>/``. Destination class folders are created
    when they do not exist yet.

    Args:
        fileDir: Directory holding the original images of one class.
        class_name: Name of the class sub-folder created under each
            destination root.
        rate: Fraction of the images assigned to the training split.
            Defaults to the module-level ``train_rate`` (set in the
            ``__main__`` block) so existing two-argument calls keep working.
        dest_dirs: ``[train_root, test_root]``. Defaults to the module-level
            ``save_dir``.

    Raises:
        ValueError: If ``rate`` gives a sample size that is negative or larger
            than the number of images (raised by ``random.sample``).
    """
    # Fall back to the script-level globals used by the original two-argument API.
    if rate is None:
        rate = train_rate
    if dest_dirs is None:
        dest_dirs = save_dir

    # Only regular files can be copied; sub-directories would make shutil.copy fail.
    image_list = [name for name in os.listdir(fileDir)
                  if os.path.isfile(os.path.join(fileDir, name))]

    train_number = int(len(image_list) * rate)
    # Randomly pick a `rate` fraction of the images for training (e.g. 0.2 -> 20 %).
    train_sample = random.sample(image_list, train_number)
    train_set = set(train_sample)
    # Everything not picked for training goes to the test split, in listdir order.
    test_sample = [name for name in image_list if name not in train_set]

    # Copy the images into the destination class folders.
    for dest_root, names in zip(dest_dirs, (train_sample, test_sample)):
        # os.path.join works whether or not dest_root ends with a separator.
        class_dir = os.path.join(dest_root, class_name)
        # exist_ok=True replaces the duplicated isdir/makedirs branches.
        os.makedirs(class_dir, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(fileDir, name), os.path.join(class_dir, name))
if __name__ == '__main__':
    time_start = time.time()

    # Root of the original dataset; each sub-folder is one class.
    origion_path = '/home/room/lm_other/NWPU-RESISC45/'

    # Output roots for the train and test splits.
    save_train_dir = '/home/room/lm_other/RS_45/2_8/train/'
    save_test_dir = '/home/room/lm_other/RS_45/2_8/test/'
    # NOTE: copyFile reads `save_dir` and `train_rate` as module globals,
    # so these names must stay as they are.
    save_dir = [save_train_dir, save_test_dir]

    # Fraction of each class used for training (0.2 -> the "2_8" train/test split).
    train_rate = 0.2

    # Dataset classes: only sub-directories of the dataset root, in a stable order.
    file_list = sorted(
        name for name in os.listdir(origion_path)
        if os.path.isdir(os.path.join(origion_path, name))
    )

    for class_name in file_list:
        image_Dir = os.path.join(origion_path, class_name)
        copyFile(image_Dir, class_name)
        print('%s划分完毕!' % class_name)

    time_end = time.time()
    print('---------------')
    print('训练集和测试集划分共耗时%s!' % (time_end - time_start))