# 数据集分布情况统计
![数据集分布情况统计图](https://img-blog.csdnimg.cn/3a8f6341083c40d88f09bab1e8038fbb.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBAMTIyNmtt,size_20,color_FFFFFF,t_70,g_se,x_16)
import os
import numpy as np
import collections
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('TkAgg')
# Count how many training samples fall into each class and plot the result
# as a bar chart (saved to disk, then shown on screen).
base_path = 'data/'
train_list_name = 'train1.txt'
figure_path = 'data/Distribution/TrainDataDistribution.jpg'


def load_labels(list_path, encoding=None):
    """Return the integer labels read from a ``<sample>,<label>`` list file.

    Each non-blank line is stripped and split on ``','``; the second field
    is parsed with ``int``. Lines that do not have exactly two fields are
    reported on stdout and skipped instead of aborting the whole run.

    Args:
        list_path: path of the text file to read.
        encoding: text encoding passed to ``open``; ``None`` keeps the
            platform default, as the original script did.

    Returns:
        The labels, in file order.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a label field is not an integer.
    """
    labels = []
    name = os.path.basename(list_path)
    with open(list_path, encoding=encoding) as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not errors.
                continue
            fields = line.split(',')
            if len(fields) != 2:
                print('%s line %d contains an invalid label entry: %r'
                      % (name, line_no, line))
                continue
            labels.append(int(fields[1]))
    return labels


def count_per_category(labels):
    """Return ``(categories, counts)`` covering every label from min to max.

    Categories are the consecutive integers from the smallest to the
    largest label seen, so classes with no samples appear with a count of
    0 instead of being silently dropped. An empty input yields ``([], [])``.
    """
    if not labels:
        return [], []
    counter = collections.Counter(labels)
    categories = list(range(min(counter), max(counter) + 1))
    # Counter returns 0 for keys it has never seen.
    return categories, [counter[c] for c in categories]


def plot_distribution(categories, counts, save_path,
                      title='TrainData Distribution', xlabel='CategoryTrain'):
    """Draw one bar of ``counts[i]`` per ``categories[i]``, save it, then show it.

    The parent directory of *save_path* is created if missing, because
    ``savefig`` does not create directories.
    """
    tick_labels = [str(c) for c in categories]  # categorical x axis
    plt.figure(figsize=(15, 10))
    plt.bar(tick_labels, counts, align='center')
    plt.xlabel(xlabel)
    plt.ylabel('Number')
    plt.title(title)
    for pos, count in enumerate(counts):
        # Anchor the label's bottom at the bar top: scale-independent,
        # unlike a fixed "+10" offset in data units.
        plt.text(pos, count, '%s' % count, ha='center', va='bottom')
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.show()


def main():
    """Read the training list under ``base_path`` and plot its class counts."""
    labels = load_labels(os.path.join(base_path, train_list_name))
    categories, counts = count_per_category(labels)
    plot_distribution(categories, counts, figure_path)


if __name__ == '__main__':
    main()