# PA100K dataset: write the labels to txt files
import json
import os
# File helpers
def txt2list(filename):
    """Return the bare image names listed in an image-name file.

    The first line of the file is skipped. For every following line, the
    text before the first ``.`` is taken and then split on single quotes;
    the piece after the first quote is kept (e.g. ``['000001.jpg']`` yields
    ``000001``).

    Args:
        filename: Path of the text file to read.

    Returns:
        list[str]: One extracted name per data line, in file order.

    Raises:
        IndexError: If the text before a data line's first dot contains no
            single quote.
    """
    # The context manager closes the file even when a malformed line raises.
    with open(filename) as f:
        next(f, None)  # skip the first line; tolerate an empty file
        return [line.split(".")[0].split("'")[1] for line in f]
def label2list(filename):
    """Return the label lines of a label file, without the first line.

    The first line of the file is skipped. Each following line is returned
    with its leading/trailing newline characters removed; no other parsing
    is done.

    Args:
        filename: Path of the text file to read.

    Returns:
        list[str]: One string per data line, in file order.
    """
    # The context manager guarantees the handle is released on any exit path.
    with open(filename) as f:
        next(f, None)  # skip the first line; tolerate an empty file
        return [line.strip('\n') for line in f]
def writetxt(path, txt):
    """Write ``txt`` to the file at ``path``.

    The file is opened in ``"w+"`` mode, so it is created if missing and
    any existing content is truncated before writing.
    """
    with open(path, "w+") as out_file:
        out_file.write(txt)
if __name__ == "__main__":
    # Read the image names and their label rows (both files have one leading
    # line that the readers skip).
    test_img = "test_images_name.txt"
    test_img_list = txt2list(test_img)
    test_label = "test_label.txt"
    test_label_list = label2list(test_label)

    path = "./test_label"
    # exist_ok=True replaces the check-then-create pair: it is a no-op when
    # the directory already exists and avoids the race between the check and
    # the creation.
    os.makedirs(path, exist_ok=True)

    # Names and labels are paired by position; zip stops at the shorter list.
    for i, j in zip(test_img_list, test_label_list):
        i_name = i + ".txt"
        name = os.path.join(path, i_name)
        # One line per file: "<name>.jpg" followed by the labels with the
        # commas turned into spaces.
        txt = i + ".jpg" + " " + j.replace(",", " ")
        writetxt(name, txt)
# Analysis
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
# Set the sans-serif font to SimHei; presumably chosen so the Chinese axis
# labels and title used in this script render (requires SimHei to be installed).
plt.rcParams['font.sans-serif'] = 'SimHei'
def getTxtAndNum(label_dir, class_num):
    """Count, per class, how many label rows flag that class with 1.

    Every regular file directly inside ``label_dir`` is read. Each line is
    split on whitespace; the first field is ignored and every remaining
    field is parsed as an integer. A value of 1 at position ``k`` adds one
    to the count of class ``k``.

    Args:
        label_dir: Directory that holds the label files.
        class_num: Number of classes; the length of the returned array.

    Returns:
        numpy.ndarray: 1-D array of ``class_num`` counts.

    Raises:
        AssertionError: If ``label_dir`` is not an existing directory.
        IndexError: If a field equal to 1 sits at a position >= ``class_num``.
        ValueError: If a label field cannot be parsed as an integer.
    """
    if not Path(label_dir).is_dir():
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; exception type and message are kept.
        raise AssertionError("label_dir is not exist")

    class_list = [0] * class_num  # running count for each class index
    for entry in os.listdir(label_dir):
        file_path = os.path.join(label_dir, entry)
        if not os.path.isfile(file_path):
            # os.listdir also returns sub-directories, which cannot be opened.
            continue
        with open(file_path, 'r') as f:
            for line in f:  # line looks like: name 0 1 1 1
                flags = line.split()[1:]  # drop the leading name field
                for idx, flag in enumerate(flags):
                    if int(flag) == 1:
                        class_list[idx] += 1
    return np.array(class_list)
def draw_class_distribution(data):
    """Show a bar chart of per-attribute counts with the value above each bar.

    Args:
        data: Indexable sequence of counts, positionally matched to the
            attribute labels defined below (expected to hold one value per
            label, i.e. 26 entries).

    Returns:
        None. The chart is displayed with ``plt.show()``.
    """
    # Attribute names prefixed with their column index. The spellings fix the
    # garbled labels that were previously shown on the axis
    # ("AgeOve r60", "HoLd0bjectsInFront", "ShortSLeeve", trailing space
    # after "LowerStripe").
    x = [
        "0:Female", "1:AgeOver60", "2:Age18-60", "3:AgeLess18",
        "4:Front", "5:Side", "6:Back",
        "7:Hat", "8:Glasses",
        "9:HandBag", "10:ShoulderBag", "11:Backpack", "12:HoldObjectsInFront",
        "13:ShortSleeve", "14:LongSleeve",
        "15:UpperStride", "16:UpperLogo", "17:UpperPlaid", "18:UpperSplice",
        "19:LowerStripe", "20:LowerPattern",
        "21:LongCoat", "22:Trousers", "23:Shorts", "24:Skirt&Dress",
        "25:boots",
    ]
    print(len(x))

    plt.figure(figsize=(40, 10))  # wide canvas so the 26 tick labels fit
    plt.bar(x, data, width=0.5, align="center")  # bar chart
    # plt.plot(x, data)  # line-chart alternative

    # Write each count slightly above its bar.
    for label, count in zip(x, data):
        plt.text(label, count + 0.01, "%d" % int(count),
                 ha='center', fontsize=15, color="r")

    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.xlabel('类别', fontsize=16)
    plt.ylabel('数量', fontsize=16)
    plt.title('PA100K测试集分布', fontsize=16)
    plt.show()
    # Save to disk:
    # plt.savefig("")
if __name__ == '__main__':
    # Tally the 26 label columns across every file in ./test_label,
    # print the raw counts, then plot them.
    counts = getTxtAndNum("./test_label", 26)
    print(counts)
    draw_class_distribution(counts)