import matplotlib. pyplot as plt
import numpy as np
# 1、统计训练和测试精度 (plot the training and test accuracy curves)
def data_plot(path, show=True):
    """Read training/test accuracy values from a text log and plot them.

    A line containing ``"Train Acc"`` contributes to the training curve and a
    line containing ``"Test Acc"`` to the test curve (a line containing both is
    counted as training). The value is the text between the first and second
    ``":"`` on that line; lines whose value part is empty are skipped.

    Args:
        path: Path of the UTF-8 encoded log file.
        show: When True (the default), call ``plt.show()`` after drawing.

    Returns:
        A ``(Train, Test)`` tuple of the parsed accuracy values as floats.
    """
    Train = []
    Test = []
    with open(path, mode="r", encoding="utf-8") as f:
        for item in f:
            if "Train Acc" in item:
                target = Train
            elif "Test Acc" in item:
                target = Test
            else:
                continue
            value = item.split(":")[1].strip()
            if value:
                target.append(float(value))

    # Use two separate axes so the training curve is still visible when the
    # figure is displayed (clearing the figure between the plots erased it).
    fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=(10, 4))
    ax_train.set_title('train')
    ax_train.plot(Train)
    ax_test.set_title('test')
    ax_test.plot(Test)
    fig.tight_layout()
    if show:
        plt.show()
    return Train, Test
if __name__ == '__main__':
    # Plot the accuracy curves recorded in the sample training log.
    data_plot(r"./[6, 6, 6, 6]_Standard_taz.csv")
# 2、计算混淆矩阵的精度、召回率和F1 (per-class and overall precision, recall and F1 of a confusion matrix)
def calculate_prediction(metrix):
    """Compute per-class and overall precision of a square confusion matrix.

    Columns are treated as predicted classes: the precision of class ``i`` is
    ``metrix[i][i]`` divided by the sum of column ``i``. The overall value is
    the number of diagonal (correct) entries over the total count, i.e. the
    accuracy. All values are percentages rounded to 4 decimals.

    A class that was never predicted (empty column), or an empty/all-zero
    matrix, yields 0 instead of nan, matching ``calculate_recall``.

    Args:
        metrix: Square confusion matrix (ndarray or nested sequence).

    Returns:
        ``(label_pre, all_pre)``: the list of per-class precisions and the
        overall precision.
    """
    metrix = np.asarray(metrix)
    col_sums = metrix.sum(axis=0)  # computed once instead of per class
    label_pre = []
    for i in range(metrix.shape[0]):
        pre = 0.0
        if col_sums[i] != 0:
            pre = round(100 * float(metrix[i][i]) / float(col_sums[i]), 4)
        label_pre.append(pre)
    print("每类精度:", label_pre)
    total = metrix.sum()
    all_pre = 0.0
    if total != 0:
        all_pre = round(100 * float(np.trace(metrix)) / float(total), 4)
    print("总精度:", all_pre)
    return label_pre, all_pre
def calculate_recall(metrix):
    """Compute per-class recall and the unweighted mean recall.

    Rows are treated as true classes: the recall of class ``k`` is
    ``metrix[k][k]`` divided by the sum of row ``k``, expressed as a
    percentage rounded to 4 decimals (0 when the row sums to zero).
    The overall recall is the plain average of the rounded per-class
    values, rounded again to 4 decimals.

    Args:
        metrix: Square confusion matrix as a NumPy array.

    Returns:
        ``(label_recall, all_recall)``.
    """
    row_totals = metrix.sum(axis=1)
    label_recall = [
        round(100 * float(metrix[k][k]) / float(row_totals[k]), 4)
        if row_totals[k] != 0 else 0
        for k in range(metrix.shape[0])
    ]
    print("每类召回率:", label_recall)
    class_count = metrix.shape[0]
    all_recall = round(np.array(label_recall).sum() / class_count, 4)
    print("总召回率:", all_recall)
    return label_recall, all_recall
def _f1_score(pre, reca):
    """Return the harmonic mean of pre and reca rounded to 4 decimals, or 0 if both are 0."""
    if (pre + reca) == 0:
        return 0
    return round(2 * pre * reca / (pre + reca), 4)


def calculate_f1(prediction, all_pre, recall, all_recall):
    """Compute per-class F1 scores and print them together with the overall F1.

    Each F1 is the harmonic mean of a precision and a recall, rounded to 4
    decimals. When both inputs are 0, the F1 is 0 instead of raising
    ZeroDivisionError. This applies to the overall score as well as the
    per-class ones.

    NOTE(review): in this file the overall precision comes from
    ``calculate_prediction`` (correct/total, i.e. accuracy). The overall
    recall comes from ``calculate_recall`` (mean of per-class recalls). The
    overall F1 therefore mixes micro- and macro-averaging; confirm that is
    intended.

    Args:
        prediction: Per-class precisions.
        all_pre: Overall precision.
        recall: Per-class recalls, same length and order as ``prediction``.
        all_recall: Overall recall.

    Returns:
        The list of per-class F1 scores (the overall F1 is only printed).
    """
    all_f1 = [_f1_score(pre, reca) for pre, reca in zip(prediction, recall)]
    print("每类f1:", all_f1)
    print("总的f1:", _f1_score(all_pre, all_recall))
    return all_f1
if __name__ == '__main__':
    # Worked example: a 5-class confusion matrix. The metric helpers divide
    # recall by row sums and precision by column sums.
    confusion = np.array([
        [84, 30, 16, 4, 4],
        [11, 88, 14, 5, 1],
        [13, 31, 75, 0, 0],
        [12, 15, 3, 71, 1],
        [31, 7, 5, 12, 67],
    ])
    col_totals = confusion.sum(axis=0)
    row_totals = confusion.sum(axis=1)
    print(col_totals[0], row_totals[0])
    precisions, overall_precision = calculate_prediction(confusion)
    recalls, overall_recall = calculate_recall(confusion)
    calculate_f1(precisions, overall_precision, recalls, overall_recall)
# 3、绘制混淆矩阵展示图形,以及混淆矩阵平均值 (plot the confusion matrix and compute the average confusion matrix)
def _parse_confusion_matrices(lines):
    """Collect NumPy-style printed integer matrices from log lines as lists of string rows."""
    matrices = []
    current = []
    for item in lines:
        # Accuracy summary lines are not part of any matrix.
        if ("Train" in item) or ("Test" in item):
            continue
        starts = "[[" in item
        ends = "]]" in item
        if "[" not in item and not ends:
            continue
        if starts:
            current = []
        # Drop the surrounding brackets ("[[", "[", "]", "]]") and split on whitespace.
        current.append(item.strip().strip("[]").split())
        # A matrix printed on a single line both starts and ends here.
        if ends:
            matrices.append(current)
    return matrices


def get_Confusion_matrix(path, numCount=200):
    """Average the last ``numCount`` confusion matrices printed in a log file.

    The log is read as UTF-8. Lines mentioning "Train" or "Test" are skipped.
    Matrices are expected in NumPy's printed form, e.g.::

        [[84 30 16  4  4]
         [11 88 14  5  1]
         ...
         [31  7  5 12 67]]

    The averaged matrix is printed, passed to ``plot_Confusion_matrix``
    (defined elsewhere), and its precision/recall/F1 are printed.

    Args:
        path: Path of the log file.
        numCount: How many of the most recent matrices to average (default 200).
            If the log holds fewer, all of them are averaged.

    Returns:
        The averaged confusion matrix as a float ndarray.

    Raises:
        ValueError: If ``numCount`` is not positive or the log contains no matrix.
    """
    if numCount <= 0:
        raise ValueError(f"numCount must be positive, got {numCount}")
    with open(path, mode="r", encoding="utf-8") as f:
        matrices = _parse_confusion_matrices(f)
    recent = matrices[-numCount:]
    if not recent:
        raise ValueError(f"no confusion matrix found in {path!r}")
    for temp in recent:
        print(temp)
    # Divide by the number of matrices actually selected: slicing yields fewer
    # than numCount entries when the log is short.
    metrix = np.array(recent, dtype=int).sum(axis=0) / len(recent)
    print(metrix)
    plot_Confusion_matrix(metrix=metrix)
    print(metrix.sum(axis=0)[0], metrix.sum(axis=1)[0])
    label_pre, all_pre = calculate_prediction(metrix)
    label_recall, all_recall = calculate_recall(metrix)
    calculate_f1(label_pre, all_pre, label_recall, all_recall)
    return metrix
if __name__ == '__main__':
    # Directory of the experiment results; the log file lives in its lstm subfolder.
    results_dir = r"C:\Users\Administrator\Desktop\nj单一特征对比实验\12-16号graphsage的结果"
    get_Confusion_matrix(results_dir + r"\lstm\[7, 7, 7, 7]_Standard_taz_lstm.csv")