统计每个类别的标签框数目

功能:统计每个类别的标签框数目,方便查看各类别数据量是否均衡。

import os  
import io  
import math
import sys
import numpy as np

import matplotlib.pyplot as plt
from collections import namedtuple, OrderedDict  

# Class names in class-index order (0 = person, 1 = car, 2 = bus, 3 = truck),
# matching the order in which convert_dataset reports its per-class counts.
label_names = [
    'person',
    'car',
    'bus',
    'truck',
]

def get_files(dir, suffix):
    """Return the paths of all files below ``dir`` whose extension is ``suffix``.

    The directory tree is walked recursively with ``os.walk``. The extension
    comparison is exact (case-sensitive) and ``suffix`` must include the
    leading dot, e.g. ``'.txt'``.

    Args:
        dir: Root directory to search.
        suffix: File extension to keep, including the dot.

    Returns:
        A list of paths, each built by joining the containing directory with
        the file name, in the order ``os.walk`` yields them.
    """
    return [
        os.path.join(parent, entry)
        for parent, _subdirs, entries in os.walk(dir)
        for entry in entries
        if os.path.splitext(entry)[1] == suffix
    ]
def convert_dataset(list_path, output_dir):
    """Count labelled boxes per class under ``list_path`` and plot the totals.

    Every ``.txt`` label file found recursively under ``list_path`` is read
    line by line. The first whitespace-separated token of each non-blank line
    is parsed as an integer class index into ``label_names``; indices outside
    ``0 .. len(label_names) - 1`` are ignored. A label file is only counted
    when a ``.jpg`` with the same stem exists next to it; otherwise it is
    reported as an error file and skipped.

    The per-class totals are printed, drawn as a bar chart, saved as
    ``train-data.png`` in the current working directory and displayed with
    ``plt.show()``.

    Args:
        list_path: Root directory that is searched for label files.
        output_dir: Kept for interface compatibility; this function does not
            use it.

    Raises:
        ValueError: If the first token of a non-blank label line is not an
            integer.
    """
    # Collect every .txt label file below list_path.
    label_list = get_files(list_path, '.txt')
    total_label_len = len(label_list)
    num_classes = len(label_names)
    class_count = [0] * num_classes
    error_count = 0
    # Stays None when no label file exists, so the summary print below does
    # not fail on an unbound local.
    image_path = None

    for i, label_file in enumerate(label_list):
        sys.stdout.write('\r>> Calculating {}/{} error{}'.format(
            i + 1, total_label_len, error_count))
        sys.stdout.flush()

        # A label file is only valid when its sibling image exists.
        file_name, type_name = os.path.splitext(label_file)
        image_path = file_name + '.jpg'
        if type_name != '.txt' or not os.path.exists(image_path):
            error_count += 1
            print("error_file: ", label_file.encode('UTF-8', 'ignore').decode('UTF-8'))
            continue

        # The context manager closes the file even if parsing raises.
        with open(label_file, 'r') as fd:
            for line in fd:
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) carry no label.
                    continue
                class_index = int(fields[0])
                # Explicit lower bound: a negative index must not wrap around
                # to the end of class_count.
                if 0 <= class_index < num_classes:
                    class_count[class_index] += 1

    total = sum(class_count)

    print("image_path: ", image_path)
    print("label_names: ", label_names)
    print("class_count: ", class_count)

    index = np.arange(num_classes)
    plt.bar(index, class_count, width=0.45, label="total: " + str(total), color="#8C8C00")

    # Write each count just above its bar.
    for a, b in zip(index, class_count):
        plt.text(a, b + 0.05, '%.0f' % b, ha='center', va='bottom', fontsize=10)
    plt.legend()
    plt.xticks(index, label_names)
    plt.title("train-data")
    plt.savefig("train-data.png")
    plt.show()
    print('total_label_len', total_label_len)


def main():
    """Run the per-class label count on the hard-coded training directory."""
    convert_dataset(list_path='/root/dhx/train', output_dir='/root/dhx/output')


if __name__ == "__main__":
    main()

结果如下:
在这里插入图片描述

# K-means clustering of the rows read from 附件1.csv: fits the model, then
# prints the cluster centres, the per-sample labels and the number of samples
# in each cluster.
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
from pandas import DataFrame
from sklearn.decomposition import PCA

plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
pd.set_option('display.max_rows', None)  # show all rows
pd.set_option('display.max_columns', None)  # show all columns
np.set_printoptions(threshold=np.inf)
# NOTE: this overrides the display.max_columns=None setting made above.
pd.set_option('display.max_columns', 9000)
pd.set_option('display.width', 9000)
pd.set_option('display.max_colwidth', 9000)

df = pd.read_csv(r'附件1.csv', encoding='gbk')
X = np.array(df.iloc[:, 1:])
# Drops one more leading column, so the first two CSV columns are excluded.
X = X[0:, 1:]
k = 93
kmeans_model = KMeans(n_clusters=k, random_state=123)
fit_kmeans = kmeans_model.fit(X)  # train the model

# Inspect the clustering result.
kmeans_cc = kmeans_model.cluster_centers_  # cluster centres
print('各类聚类中心为:\n', kmeans_cc)
kmeans_labels = kmeans_model.labels_  # cluster label of each sample
print('各样本的类别标签为:\n', kmeans_labels)
r1 = pd.Series(kmeans_model.labels_).value_counts()  # number of samples per cluster
print('最终每个类别数目为:\n', r1)

# Output the cluster grouping result (left disabled, as in the original).
# cluster_center = pd.DataFrame(kmeans_model.cluster_centers_,
#                               columns=[str(x) for x in range(1, 94)])  # put the cluster centres in a frame
# cluster_center.index = pd.DataFrame(kmeans_model.labels_). \
#     drop_duplicates().iloc[:, 0]  # use the sample labels as the index
# print(cluster_center)

# Trailing text fused to the end of this snippet in the original line: "代码解释" ("code explanation").
06-13
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值