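# Before training the PaddleOCR direction-classification model, all images
# sit in a single folder and all label information is in one txt file, so
# this script splits them into train/val/test sets at an 8:1:1 ratio.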
import os
import re
import shutil
import random
import argparse
def split_label(all_label, train_label, val_label, test_label):
    """Shuffle a label file and split it 8:1:1 into train/val/test files.

    Args:
        all_label: Path of the text file holding one label record per line.
        train_label: Output path receiving the first 80% of the shuffled
            records (rounded down).
        val_label: Output path receiving the next 10% (rounded down).
        test_label: Output path receiving every remaining record.

    The shuffle uses the module-level ``random`` state, so call
    ``random.seed`` beforehand if a reproducible split is needed.
    """
    with open(all_label, 'r') as f:
        # Skip blank lines and make every record end with exactly one
        # newline; otherwise a last line lacking its newline could be
        # shuffled into the middle and fuse with the record after it.
        raw_list = [line.rstrip('\n') + '\n' for line in f if line.strip()]

    num_train = int(len(raw_list) * 0.8)
    num_val = int(len(raw_list) * 0.1)
    val_end = num_train + num_val

    random.shuffle(raw_list)

    with open(train_label, 'w') as f_train:
        f_train.writelines(raw_list[:num_train])
    with open(val_label, 'w') as f_val:
        f_val.writelines(raw_list[num_train:val_end])
    # The test split takes everything that is left, so the rounding done by
    # int() above never silently drops records.
    with open(test_label, 'w') as f_test:
        f_test.writelines(raw_list[val_end:])
def _move_listed_images(label_path, src_dir, dst_dir):
    """Move every image named in *label_path* from *src_dir* into *dst_dir*."""
    with open(label_path, 'r') as f:
        for line in f:
            # A blank line (e.g. a trailing empty line) names no image.
            if not line.strip():
                continue
            # Records are assumed to look like "<image path>\t<label>". The
            # images are looked up directly inside src_dir, so only the file
            # name of the path is used, however many directories precede it.
            img_name = os.path.basename(line.split('\t', 1)[0].strip())
            if not img_name:
                # Joining an empty name would yield src_dir itself and move
                # the whole folder, so refuse instead.
                raise ValueError(
                    "no image file name in label line %r of %s"
                    % (line, label_path))
            shutil.move(os.path.join(src_dir, img_name), dst_dir)


def split_img(all_imgs, train_label, train_imgs, val_label, val_imgs, test_label, test_imgs):
    """Move images into train/val/test folders according to the label files.

    Args:
        all_imgs: Directory that currently contains every image.
        train_label: Label file listing the training images.
        train_imgs: Directory the training images are moved into.
        val_label: Label file listing the validation images.
        val_imgs: Directory the validation images are moved into.
        test_label: Label file listing the test images.
        test_imgs: Directory the test images are moved into.

    The three destination directories are expected to exist already.

    Raises:
        ValueError: If a non-blank label line contains no image file name.
        OSError: Propagated from ``shutil.move``, e.g. when a listed image
            is missing from ``all_imgs``.
    """
    _move_listed_images(train_label, all_imgs, train_imgs)
    _move_listed_images(val_label, all_imgs, val_imgs)
    _move_listed_images(test_label, all_imgs, test_imgs)
def get_args():
    """Parse the command-line options that locate the dataset files.

    Every option is optional and falls back to a default path under
    ``../paddleocr/PaddleOCR/train_data/cls/``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, with the
        attributes ``all_label``, ``all_imgs_dir``, ``train_label``,
        ``train_imgs_dir``, ``val_label``, ``val_imgs_dir``, ``test_label``
        and ``test_imgs_dir``.
    """
    # (flag, default) pairs, registered in this order.
    option_defaults = (
        ("--all_label", "../paddleocr/PaddleOCR/train_data/cls/cls_gt_train.txt"),
        ("--all_imgs_dir", "../paddleocr/PaddleOCR/train_data/cls/images/"),
        ("--train_label", "../paddleocr/PaddleOCR/train_data/cls/train.txt"),
        ("--train_imgs_dir", "../paddleocr/PaddleOCR/train_data/cls/train/"),
        ("--val_label", "../paddleocr/PaddleOCR/train_data/cls/val.txt"),
        ("--val_imgs_dir", "../paddleocr/PaddleOCR/train_data/cls/val/"),
        ("--test_label", "../paddleocr/PaddleOCR/train_data/cls/test.txt"),
        ("--test_imgs_dir", "../paddleocr/PaddleOCR/train_data/cls/test/"),
    )

    cli = argparse.ArgumentParser()
    for flag, default_path in option_defaults:
        cli.add_argument(flag, default=default_path)
    return cli.parse_args()
def main(args):
    """Create the output image folders, then split the labels and images.

    Args:
        args: Object exposing the attributes ``all_label``, ``all_imgs_dir``,
            ``train_label``, ``train_imgs_dir``, ``val_label``,
            ``val_imgs_dir``, ``test_label`` and ``test_imgs_dir``.
    """
    # Make sure each destination folder exists before anything is moved.
    for split_dir in (args.train_imgs_dir, args.val_imgs_dir, args.test_imgs_dir):
        if os.path.isdir(split_dir):
            continue
        os.makedirs(split_dir)

    # The label files are written first; the image step reads them back.
    split_label(
        args.all_label,
        args.train_label,
        args.val_label,
        args.test_label,
    )
    split_img(
        args.all_imgs_dir,
        args.train_label,
        args.train_imgs_dir,
        args.val_label,
        args.val_imgs_dir,
        args.test_label,
        args.test_imgs_dir,
    )
# Script entry point: parse the command line, then run the split.
if __name__ == "__main__":
    cli_args = get_args()
    main(cli_args)