将数据集image和mask同时划分成train、val、test

import os
import argparse
import random
import shutil
from shutil import copyfile

def rm_mkdir(dir_path):
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
        print('Remove path - %s' % dir_path)
    os.makedirs(dir_path)
    print('Create path - %s' % dir_path)


def main(config):

    rm_mkdir(config.train_path)
    rm_mkdir(config.train_GT_path)
    rm_mkdir(config.valid_path)
    rm_mkdir(config.valid_GT_path)
    rm_mkdir(config.test_path)
    rm_mkdir(config.test_GT_path)

    filenames = os.listdir(config.origin_data_path)
    masknames = os.listdir(config.origin_GT_path)
    data_list = []
    GT_list = []
    
    for filename in filenames:
        data_list.append(filename)
    for maskname in masknames:
        GT_list.append(maskname)
    
    data_list = sorted(data_list)
    GT_list = sorted(GT_list)

    num_total = len(data_list)
    num_train = int((config.train_ratio/(config.train_ratio +
                    config.valid_ratio+config.test_ratio))*num_total)
    num_valid = int((config.valid_ratio/(config.train_ratio +
                    config.valid_ratio+config.test_ratio))*num_total)
    num_test = num_total - num_train - num_valid

    print('\nNum of train set : ', num_train)
    print('\nNum of valid set : ', num_valid)
    print('\nNum of test set : ', num_test)

    Arange = list(range(num_total))
    random.shuffle(Arange)

    for i in range(num_train):
        idx = Arange.pop()
        print(idx)
        src = os.path.join(config.origin_data_path, data_list[idx])
        dst = os.path.join(config.train_path, data_list[idx])
        copyfile(src, dst)
        print(src, dst)
        src = os.path.join(config.origin_GT_path, GT_list[idx])
        dst = os.path.join(config.train_GT_path, GT_list[idx])
        copyfile(src, dst)
        print(src, dst)

    for i in range(num_valid):
        idx = Arange.pop()

        src = os.path.join(config.origin_data_path, data_list[idx])
        dst = os.path.join(config.valid_path, data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_GT_path, GT_list[idx])
        dst = os.path.join(config.valid_GT_path, GT_list[idx])
        copyfile(src, dst)

    for i in range(num_test):
        idx = Arange.pop()

        src = os.path.join(config.origin_data_path, data_list[idx])
        dst = os.path.join(config.test_path, data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_GT_path, GT_list[idx])
        dst = os.path.join(config.test_GT_path, GT_list[idx])
        copyfile(src, dst)
    

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model hyper-parameters
    parser.add_argument('--train_ratio', type=float, default=0.8)
    parser.add_argument('--valid_ratio', type=float, default=0.1)
    parser.add_argument('--test_ratio', type=float, default=0.1)

    # data path
    parser.add_argument('--origin_data_path', type=str,
                        default='/data/tanen/IID/images')
    parser.add_argument('--origin_GT_path', type=str,
                        default='/data/tanen/IID/masks')

    parser.add_argument('--train_path', type=str, default='/data/tanen/IID/train/')
    parser.add_argument('--train_GT_path', type=str,
                        default='/data/tanen/IID/train_mask/')
    parser.add_argument('--valid_path', type=str, default='/data/tanen/IID/val/')
    parser.add_argument('--valid_GT_path', type=str,
                        default='/data/tanen/IID/val_mask/')
    parser.add_argument('--test_path', type=str, default='/data/tanen/IID/test/')
    parser.add_argument('--test_GT_path', type=str,
                        default='/data/tanen/IID/test_mask/')

    config = parser.parse_args()
    print(config)
    main(config)

 只需要根据自己需求修改参数即可!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值