1、获取分类 train 目录下各类别的图像数量
import os
def count_images_per_class(root=r"./data/train"):
    """Count the entries in each class sub-directory of *root*.

    Every sub-directory of *root* is treated as one class. Plain files sitting
    directly in *root* (e.g. a stray ``notes.txt``) are skipped.

    Args:
        root: dataset directory containing one sub-directory per class.

    Returns:
        dict mapping class directory name -> number of entries (files and
        sub-directories) directly inside it, in sorted class-name order.
    """
    counts = {}
    for class_name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, class_name)
        # Skip plain files with an explicit check instead of a bare
        # ``except``, so permission errors and Ctrl-C are not silently eaten.
        if os.path.isdir(class_dir):
            counts[class_name] = len(os.listdir(class_dir))
    return counts


if __name__ == "__main__":
    for class_name, count in count_images_per_class().items():
        print(f"{class_name}:{count}")
2、将 train 数据集中的图像按比例随机移动到 test 中，也可以再移回 train。
import os
import random
import cv2
import shutil
def _list_files(directory):
    """Return the names of the regular files directly inside *directory*, sorted."""
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


def _preview(path, show):
    """Briefly display the image at *path* in an OpenCV window when *show* is true."""
    if not show:
        return
    img = cv2.imread(path)
    # cv2.imread returns None instead of raising for unreadable or non-image
    # files; cv2.imshow would then fail, so just skip the preview.
    if img is not None:
        cv2.imshow("img", img)
        cv2.waitKey(1)


def get_txt(data_path, save_path, trainval_percent, classname, type, show=True):
    """Move a random share of one class's files from train to test, or back.

    Despite its name, this function writes no .txt files; it moves files.
    Every regular file is moved under its own name, whatever its extension.
    Sub-directories inside the class directory are ignored.

    Args:
        data_path: train root; contains one sub-directory per class.
        save_path: test root; ``<save_path>/<classname>`` is created if missing.
        trainval_percent: with ``type == 0``, the fraction of the class's files
            that stays in train. The other ``num - int(num * trainval_percent)``
            files are moved to test. Ignored when ``type == 1``.
        classname: name of the class sub-directory to process.
        type: 0 moves files train -> test. 1 moves every file in the class's
            test directory back into the train directory.
        show: when true (the default, as before), each moved image is briefly
            shown in an OpenCV window.

    Raises:
        ValueError: if ``type`` is neither 0 nor 1.
    """
    if type not in (0, 1):
        raise ValueError(f"type must be 0 (train -> test) or 1 (test -> train), got {type!r}")
    src_dir = os.path.join(data_path, classname)
    des_dir = os.path.join(save_path, classname)
    os.makedirs(des_dir, exist_ok=True)

    if type == 0:
        files = _list_files(src_dir)
        num = len(files)
        # Indices of files that stay in train; a set gives O(1) membership tests.
        keep = set(random.sample(range(num), int(num * trainval_percent)))
        for i, filename in enumerate(files):
            if i in keep:
                continue
            src_path = os.path.join(src_dir, filename)
            des_path = os.path.join(des_dir, filename)
            print(os.path.splitext(filename)[0])
            _preview(src_path, show)
            shutil.move(src_path, des_path)
    else:
        for filename in _list_files(des_dir):
            src_path = os.path.join(src_dir, filename)
            des_path = os.path.join(des_dir, filename)
            _preview(des_path, show)
            shutil.move(des_path, src_path)
if __name__ == '__main__':
    data_dir = r"./classify_data"
    data_path = os.path.join(data_dir, "train")
    save_path = os.path.join(data_dir, "test")
    # Fraction of each class's images that stays in train; the rest go to test.
    trainval_percent = 0.8
    # Only sub-directories of train are classes. A stray file (e.g. .DS_Store)
    # passed as a class would make get_txt call os.listdir on a file and fail.
    class_names = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )
    for class_name in class_names:
        # type=0 moves train -> test; pass 1 instead to move the files back.
        get_txt(data_path, save_path, trainval_percent, class_name, 0)