分类任务中分析数据分布,分析样本是否均衡等。以下为例,目录如下:
train
1
2
3
...
n
train目录下,包含n个类,每个类目录下,有很多图片。我们的目的是统计目录下的数目,然后画直方图:
'''
@author: mengting gu
@contact: 1065504814@qq.com
@time: 2021/4/26 下午3:40
@file: data_distribute.py
@desc:
'''
import os
import matplotlib.pyplot as plt
def file_list_data_distribute(filelist):
"""
Parameters
----------
filelist : train | id1, id2, id3
Returns : [len(id1), len(id2), len(id3)]
-------
"""
d = {}
for i, files in enumerate(os.listdir(filelist)):
num_files_rec = 0
jpg_path = os.path.join(filelist, files)
print("jpg_path : {}".format(jpg_path))
for file in os.listdir(jpg_path):
if file[-4:] == ".jpg":
num_files_rec += 1
# print("file : {}".format(file))
d[str(i)] = num_files_rec
print(d)
return d
if __name__ == '__main__':
file_path = "./train/"
d = file_list_data_distribute(filelist=file_path)
plt.bar(d.keys(), d.values())
plt.show()
# plt.hist(d.values(), bins=20, edgecolor='k', alpha=0.7) # 设置直方边线颜色为黑色,不透明度为 0.35
# plt.show()