# 统计txt标签格式中每个类别的数量。
# 计算txt标签格式每个类别的数量
import os
def get_every_class_num(txt_path, class_categories=None):
    """Count and print how many label rows of each class exist under a directory.

    Every regular ``.txt`` file directly inside ``txt_path`` is read as a label
    file. Each non-blank line must start with an integer class id; any further
    whitespace-separated fields on the line are ignored. The per-class counts
    and the grand total are printed to stdout.

    Args:
        txt_path: Directory that contains the label ``.txt`` files.
        class_categories: Class names ordered so that position ``i`` is the
            name of class id ``i``. Defaults to ``['class1', 'class2']``;
            pass your own list so names and ids line up with your dataset.

    Raises:
        ValueError: If a line's first field is not an integer, or the class id
            is not a valid index into ``class_categories``.
        OSError: If the directory or one of its label files cannot be read.
    """
    if class_categories is None:
        # Placeholder names; they must correspond one-to-one with the class ids.
        class_categories = ['class1', 'class2']
    class_num = len(class_categories)  # number of classes
    class_num_list = [0] * class_num

    for label_name in os.listdir(txt_path):
        file_path = os.path.join(txt_path, label_name)
        # Sub-directories and non-txt entries are not label files.
        if not label_name.lower().endswith('.txt') or not os.path.isfile(file_path):
            continue
        # `with` guarantees the handle is closed even if a row fails to parse.
        with open(file_path, 'r') as label_file:
            for line_no, every_row in enumerate(label_file, start=1):
                fields = every_row.split()
                if not fields:
                    # Blank / whitespace-only rows carry no label.
                    continue
                try:
                    class_ind = int(fields[0])
                except ValueError as err:
                    raise ValueError(
                        "{}:{}: class id {!r} is not an integer".format(
                            file_path, line_no, fields[0])
                    ) from err
                # Explicit range check: a negative id must not wrap around
                # and silently be counted against the last class.
                if not 0 <= class_ind < class_num:
                    raise ValueError(
                        "{}:{}: class id {} is outside 0..{}".format(
                            file_path, line_no, class_ind, class_num - 1)
                    )
                class_num_list[class_ind] += 1

    # Print the count of every class followed by the overall total.
    result = dict(zip(class_categories, class_num_list))
    for name, num in result.items():
        print(name, ":", num)
    print("-----------------------------------")
    print('total:', sum(class_num_list))
if __name__ == '__main__':
    # Adjust the paths below to wherever your label .txt files live.
    # Each dataset split is counted from its own directory.
    train_txt_path = './coco128/train/labels'
    print("训练集train的类别数如下:")
    get_every_class_num(train_txt_path)

    val_txt_path = './coco128/val/labels'
    print("验证集val的类别数如下:")
    get_every_class_num(val_txt_path)

    test_txt_path = './coco128/test/labels'
    print("测试集test的类别数如下:")
    get_every_class_num(test_txt_path)