mmclassification训练数据准备脚本

1.随机划分训练、验证、测试

import os
import glob
import random
import shutil

dataset_dir = './XXX_classification/'
train_dir = './datasets/train/'
valid_dir = './datasets/val/'
test_dir = './datasets/test/'

train_per = 0.8
valid_per = 0.1
test_per = 0.1


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    for root, dirs, files in os.walk(dataset_dir):
        for sDir in dirs:
            imgs_list = glob.glob(os.path.join(root, sDir)+'/*.jpg')
            random.seed(666)
            random.shuffle(imgs_list)
            imgs_num = len(imgs_list)

            train_point = int(imgs_num * train_per)
            valid_point = int(imgs_num * (train_per + valid_per))

            for i in range(imgs_num):
                if i < train_point:
                    out_dir = train_dir + sDir + '/'
                elif i < valid_point:
                    out_dir = valid_dir + sDir + '/'
                else:
                    out_dir = test_dir + sDir + '/'

                makedir(out_dir)
                out_path = out_dir + os.path.split(imgs_list[i])[-1]
                shutil.copy(imgs_list[i], out_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sDir, train_point, valid_point-train_point, imgs_num-valid_point))

2.将划分好的训练、验证、测试生成meta格式

import os
from glob import glob
from pathlib import Path


def generate_mmcls_ann(data_dir, img_type='.jpg'):
    data_dir = str(Path(data_dir)) + '/'
    classes = ['0000', '0001', '0002', '0003']
    class2id = dict(zip(classes, range(len(classes))))
    data_dir = str(Path(data_dir)) + '/'
    dir_types = ['train', 'val', 'test']

    sub_dirs = os.listdir(data_dir)
    ann_dir = data_dir + 'meta/'
    if not os.path.exists(ann_dir):
        os.makedirs(ann_dir)
    for sd in sub_dirs:
        if sd not in dir_types:
            continue
        annotations = []
        target_dir = data_dir + sd + '/'
        for d in os.listdir(target_dir):
            class_id = str(class2id[d])
            images = glob(target_dir + d + '/*' + img_type)
            for img in images:
                img = d + '/' + os.path.basename(img)
                annotations.append(img + ' ' + class_id + '\n')
        annotations[-1] = annotations[-1].strip()
        with open(ann_dir + sd + '.txt', 'w') as f:
            f.writelines(annotations)


if __name__ == '__main__':
    data_dir = './datasets/'
    generate_mmcls_ann(data_dir)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

VisionX Lab

你的鼓励将是我更新的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值