Python统计voc格式数据集类别分布并生成直方图

直接上代码:

'''
Author: TuZhou
Version: 1.0
Date: 2022-07-19 19:48:43
LastEditTime: 2022-07-20 10:54:21
LastEditors: TuZhou
Description: 
FilePath: \python_test\cul_class_num.py
'''
import glob
import xml.etree.ElementTree as ET
from matplotlib import pyplot as plt
import numpy as np

#xml文件前置路径
pre_path = "E:\\user\\dataSets\\UrbanObjectDetection_datasetTraffic\\Annotations\\"
#数据集分类标签
class_sign = ("bicycle","bus","car","motorbike","person","signal","light")
#数据集类别字典
class_list = {"bicycle":0,"bus":0,"car":0,"motorbike":0,"person":0,"trafficsignal":0,"trafficlight":0}
#需要读取的所有xml文件路径
path_list = []
#只有xml文件名的txt文件
xml_name_path = 'E:\\user\\dataSets\\UrbanObjectDetection_datasetTraffic\\ImageSets\\cleaned_dataset\\train.txt'

#读取只有xml文件名的txt文件,放入path_list中
def read_xml_path(xml_name_path):
    f = open(xml_name_path)
    while True:
        line = f.readline()  #包括换行符
        line = line[:-1]   #去掉换行符
        if line:
            line = pre_path + line + ".xml"
            path_list.append(line)
        else:
            break
    f.close()


#遍历所有xml文件并进行解析
def read_xml():
    for xml_file in path_list:
        # 返回解析树
        tree = ET.parse(xml_file)
        # 获取根节点
        root = tree.getroot()
        # 对所有目标进行解析
        for member in root.findall('object'):
        #获取object标签内的name
            objectname = member.find('name').text
            class_list[objectname] += 1

# #打印数据集各个类别的数量
# for item in class_list.items():
#     print(item)

#生成直方图
def gen_histogram(x, y, class_sign):
    fig, ax = plt.subplots()
    # 截尾平均数
    means = sum(sorted(y)[1:-1])/len(y[1:-1])
    b = ax.bar(x, y, width = 0.5, label='{}'.format(means))
    plt.title('Detection category distribution')
    for a, b in zip(x, y):
        ax.text(a, b+1, b, ha='center', va='bottom')

    plt.xticks( np.arange(1,len(class_sign)+1,1), class_sign)
    plt.ylim((1,max(y) + max(y)/10))
    plt.xticks(range(len(x)+2))
    plt.xlabel('Class')
    plt.ylabel('Number')
    plt.legend()
    # plt.savefig("类别分布.jpg", dpi=300,format="jpg")
    plt.show()


if __name__ == '__main__':
    read_xml_path(xml_name_path)
    read_xml()
    
    #y表示装入各个类别数量的列表
    y = []
    #打印数据集各个类别的数量
    for item in class_list.items():
        y.append(item[1])
        print(item)
    
    #x为类别数目
    x = range(1, len(class_list)+1)
    gen_histogram(x, y, class_sign)  

一般情况你只需要修改你的数据集分类以及xml路径即可,其中xml_name_path为只有xml文件名的txt文件路径,可酌情修改。

直方图示例:
在这里插入图片描述

  • 0
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值