python划分数据集并使各类别的数目相近

最近项目拿到了一个别人标注但没有划分的数据集,有13类,不过经过统计发现各类别的数目差距较大,最多的一类有五万多张图片,最少的一类只有两千多张,如果使用传统的划分方法,对所有的数据进行随机划分,将会导致样本严重不均衡的问题,甚至可能出现训练集中不存在某一类图片,因此考虑以最少的一类图片数目为基准,对每一类都选择两千张左右的图片,并且使用蓄水池算法保证选取的随机性,考虑到同一张图片中可能存在多个目标,并且目标也不一定是同类,因此对每一张图片的标注文件只参考其第一个标注的目标类别(如果标注文件中有没有标注的目标,需要先判断),最后对每一类图片按照数据集划分的比例随机划分到训练集、验证集、测试集中,虽然无法保证最终划分的数据集每一类图片数目非常相近,但大致差别不会太大,并且保证了训练集、验证集、测试集中每一类都会存在一定数目的图片。

import os
import xml.dom.minidom
import random

master_root = os.path.abspath(os.path.join(os.getcwd(), "../../"))
data_root = os.path.join(master_root, "name of your dataset")  # data_root = os.path.join(master_root, "coco")
ImageSets_path = os.path.join(data_root, "ImageSets/Main")
train_txt_path = os.path.join(ImageSets_path, "train.txt")
val_txt_path = os.path.join(ImageSets_path, "val.txt")
test_txt_path = os.path.join(ImageSets_path, "test.txt")
none_tag_path = os.path.join(ImageSets_path, "none_tag.txt")

xml_path = os.path.join(data_root, "Annotations/")

classes = ['classes of your dataset'] 
# classes = [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat']

files = os.listdir(xml_path)

# 蓄水池抽样算法
def add(list, size, len, file):
    if (len < size):
        list.append(file)
    else:
        i = random.randint(0, len)
        if (i < size):
            list[i] = file

def create_imagesets_train_val_test(lists, traintxt_full_path, valtxt_full_path, testtxt_full_path):
    # 训练集比例
    train_percent = 0.6
    # 验证集比例
    val_percent = 0.2
    # 测试集比例
    test_percent = 0.2

    ftrain = open(traintxt_full_path, 'w')
    fval = open(valtxt_full_path, 'w')
    ftest = open(testtxt_full_path, 'w')

    trainList = []
    valList = []
    testList = []

    for list in lists:
        num = len(list)

        num_train = int(num * train_percent)  # 训练集个数

        num_val = int(num * val_percent)  # 验证集个数
        # 随机选num_train个train文件
        train_list = random.sample(list, num_train)
        for i in train_list:
            trainList.append(i)
            list.remove(i)
        val_list = random.sample(list, num_val)
        for j in val_list:
            valList.append(j)
            list.remove(j)
        test_list = list
        for k in test_list:
            testList.append(k)

    trainList.sort()
    valList.sort()
    testList.sort()

    for i in trainList:
        ftrain.write(i)  # train.txt文件写入
    for j in valList:
        fval.write(j)  # val.txt文件写入
    for k in testList:
        ftest.write(k)  # test.txt文件写入

    ftrain.close()  # 关闭train.txt
    fval.close()  # 关闭val.txt
    ftest.close()  # 关闭test.txt


if __name__ == '__main__':

    lists = [[] for i in range(len(classes))]
    sizes = []
    length = [0 for i in range(len(classes))]
    for i in range(len(classes)):
        sizes.append(random.randint(1950, 2250))  # 大概数目

    # 记录没标注的图片
    none_tag = []
    none = open(none_tag_path, 'w')

    # 遍历所有标注文件
    for file in files:
            xmlfile = xml_path + file
            dom = xml.dom.minidom.parse(xmlfile)  # 读取xml文档
            root = dom.documentElement  # 得到文档元素对象
            objectlist = root.getElementsByTagName("object")
            if len(objectlist) == 0:
                none_tag.append(os.path.splitext(file)[0] + '\n')
            else:
                # 如果有标注就按第一个标注的对象分类
                object = objectlist[0]
                namelist = object.getElementsByTagName("name")
                objectname = namelist[0].childNodes[0].data
                if objectname in classes:
                    cls_id = classes.index(objectname)
                    add(lists[cls_id], sizes[cls_id], length[cls_id], os.path.splitext(file)[0] + '\n')  # 使用蓄水池算法实现随机选取样本
                    length[cls_id] += 1

    for n in none_tag:
        none.write(n)  # none_tag.txt文件写入

    none.close()  # 关闭none_tag.txt
    create_imagesets_train_val_test(lists, train_txt_path, val_txt_path, test_txt_path)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值