功能：统计每个类别的标签框数目，方便查看各类别的数据量是否均衡。
import os
import io
import math
import sys
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple, OrderedDict
# Class names; position k is the name of integer class id k in the label files.
label_names = ["person", "car", "bus", "truck"]
def get_files(dir, suffix):
    """Recursively collect files under ``dir`` whose extension equals ``suffix``.

    The extension comparison is exact and case-sensitive ('.txt' does not
    match '.TXT'). Each result is the walked folder joined with the file name,
    in ``os.walk`` traversal order.
    """
    return [
        os.path.join(folder, entry)
        for folder, _subdirs, entries in os.walk(dir)
        for entry in entries
        if os.path.splitext(entry)[1] == suffix
    ]
def _count_class_boxes(label_list, num_classes):
    """Count label boxes per class id across the given label files.

    A label file is skipped (and counted as an error in the progress line) if
    its extension is not '.txt' or no same-named '.jpg' exists beside it.
    For every non-blank line, the first whitespace-separated field is parsed
    as an integer class id; ids outside ``[0, num_classes)`` are ignored.

    Returns:
        (counts, image_path): ``counts[k]`` is the number of boxes with class
        id ``k``; ``image_path`` is the '.jpg' path derived from the last label
        file visited, or ``None`` if ``label_list`` is empty.
    """
    counts = [0] * num_classes
    error_count = 0
    image_path = None
    total_label_len = len(label_list)
    for i, label_file in enumerate(label_list):
        sys.stdout.write('\r>> Calculating {}/{} error{}'.format(
            i + 1, total_label_len, error_count))
        sys.stdout.flush()
        file_name, type_name = os.path.splitext(label_file)
        image_path = file_name + '.jpg'
        if type_name != '.txt' or not os.path.exists(image_path):
            error_count += 1
            print("error_file: ", label_file.encode('UTF-8', 'ignore').decode('UTF-8'))
            continue
        with open(label_file, 'r', encoding='utf-8') as fd:
            for line in fd:
                fields = line.split()
                if not fields:
                    # Blank line: previously crashed with IndexError on line[0].
                    continue
                class_index = int(fields[0])
                # Explicit lower bound: a negative id would otherwise index
                # ``counts`` from the end and be miscounted.
                if 0 <= class_index < num_classes:
                    counts[class_index] += 1
    # Terminate the '\r' progress line so later prints start on a fresh line.
    sys.stdout.write('\n')
    return counts, image_path


def _plot_class_counts(class_count, names, total):
    """Draw a bar chart of per-class counts, save it as 'train-data.png' in
    the current working directory, and show it."""
    index = np.arange(len(class_count))
    # Fresh figure so repeated calls do not draw on top of an earlier chart.
    fig = plt.figure()
    plt.bar(index, class_count, width=0.45, label="total: " + str(total), color="#8C8C00")
    for x, count in zip(index, class_count):
        plt.text(x, count + 0.05, '%.0f' % count, ha='center', va='bottom', fontsize=10)
    plt.legend()
    plt.xticks(index, names)
    plt.title("train-data")
    plt.savefig("train-data.png")
    plt.show()
    plt.close(fig)


def convert_dataset(list_path, output_dir):
    """Count label boxes per class under ``list_path`` and plot the distribution.

    All '.txt' files found recursively under ``list_path`` are treated as label
    files, each paired with a same-named '.jpg'. Class ids are looked up in the
    module-level ``label_names`` list.

    NOTE(review): ``output_dir`` is accepted but not used; the chart is written
    to 'train-data.png' in the current working directory.

    Returns:
        list[int]: box count per class, in ``label_names`` order.
    """
    label_list = get_files(list_path, '.txt')
    total_label_len = len(label_list)
    class_count, image_path = _count_class_boxes(label_list, len(label_names))
    total = sum(class_count)
    # image_path is None when no label files were found (was a NameError).
    if image_path is not None:
        print("image_path: ", image_path)
    print("label_names: ", label_names)
    print("class_count: ", class_count)
    _plot_class_counts(class_count, label_names, total)
    print('total_label_len', total_label_len)
    return class_count
def main(argv=None):
    """Command-line entry point: count label boxes per class and plot them.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behavior). With no
    arguments, the previously hard-coded paths are used, so the default
    behavior is unchanged.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Count label boxes per class and plot the distribution.')
    parser.add_argument('--list_path', default='/root/dhx/train',
                        help='directory searched recursively for .txt label files')
    parser.add_argument('--output_dir', default='/root/dhx/output',
                        help='passed through to convert_dataset')
    args = parser.parse_args(argv)
    convert_dataset(args.list_path, args.output_dir)
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
结果如下：