import pandas as pd
import numpy as np
def classifaction_report_csv(report):
    """Parse the per-class rows of a text classification report into a DataFrame.

    The report is expected to look like the string returned by
    ``sklearn.metrics.classification_report``:
    - a header line;
    - a blank line;
    - one row per class: ``<class> <precision> <recall> <f1-score> <support>``;
    - a blank line followed by summary rows.

    Only the per-class rows are parsed. These are the first run of non-blank
    lines after the header. Summary rows (``avg / total``, ``accuracy``,
    ``macro avg``, ...) are ignored.

    Columns are separated by any run of whitespace, so the space-padded
    layout of these reports is handled. A class label may itself contain
    spaces: everything before the last four fields is taken as the label.

    Args:
        report: The report text.

    Returns:
        A DataFrame with one row per class and a default RangeIndex. Its
        columns are ``class`` (str) and ``precision``, ``recall``,
        ``f1_score``, ``support`` (all float). When no class rows are found,
        an empty DataFrame with those same columns is returned.

    Raises:
        ValueError: If a class row does not end in four numeric fields.
    """
    columns = ['class', 'precision', 'recall', 'f1_score', 'support']
    report_data = []
    in_class_rows = False
    # The first line is the column header; skip it.
    for line in report.split('\n')[1:]:
        # split() with no argument collapses runs of whitespace (and strips
        # a trailing '\r'), unlike split(' ').
        fields = line.split()
        if not fields:
            if in_class_rows:
                # The blank line after the class rows starts the summary section.
                break
            # Blank line(s) between the header and the first class row.
            continue
        in_class_rows = True
        if len(fields) < 5:
            raise ValueError(f"malformed classification report row: {line!r}")
        precision, recall, f1_score, support = (float(v) for v in fields[-4:])
        report_data.append({
            'class': ' '.join(fields[:-4]),
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'support': support,
        })
    # Explicit columns keep the schema stable even when no rows were parsed.
    return pd.DataFrame(report_data, columns=columns)
# Build a 6 x 20 grid of per-class F1 scores for plotting.
# NOTE(review): `res` is not defined in this section; presumably it is the
# classification_report text produced earlier (e.g. in a notebook cell).
df1 = classifaction_report_csv(res)
# Remove the row labelled 0. The parser returns a default RangeIndex, so this
# is the first parsed class row.
# NOTE(review): the reason for dropping it is not visible here.
df1 = df1.drop(index=[0])
df_np = df1['f1_score'].values

# Use the first 20 * 6 scores. Each consecutive run of 6 becomes one column,
# labelled 1..20. Rows are labelled 6 down to 1, and each run is stored in
# reverse, so the k-th score of a run (k = 1..6) lands in the row labelled k.
n_groups, group_size = 20, 6
grid = df_np[:n_groups * group_size].reshape(n_groups, group_size)
# Flip each run, then transpose so that runs become columns. copy() makes
# the frame own its data instead of viewing df1's array.
grid = grid[:, ::-1].T.copy()
df = pd.DataFrame(
    grid,
    index=list(range(group_size, 0, -1)),
    columns=list(range(1, n_groups + 1)),
)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Draw `df` as a 20x6-inch annotated heatmap:
# - annot=True writes each cell's value inside its square;
# - cmap selects the colour map.
fig, ax = plt.subplots(figsize=(20, 6))
sns.heatmap(df, annot=True, cmap="YlGnBu", ax=ax)
plt.show()
# 将classification_report保存成dataframe画图 (save a classification_report as a DataFrame and plot it)
# 最新推荐文章于 2024-08-23 15:03:47 发布