python实现召回率、精度、f1代码

import matplotlib.pyplot as plt
import numpy as np

1、统计训练和测试精度

# 读取实验结果中的精度和损失
def data_plot(path):
    # path = r'./model_png\5gcn_128_node_pos_floor_2000.log'
    # path = r'./[6, 6]_Standard_taz.csv'
    with open(path, mode="r", encoding="utf-8") as f:
        data = f.readlines()
    Train = []
    Test = []
    for item in data:
        if "Train Acc" in item:
            # print(item)
            train = item.split(":")[1].strip()
            if train == "":
                continue
            Train.append(float(train))

        elif "Test Acc" in item:
            test = item.split(":")[1].strip()
            if test == "":
                continue
            Test.append(float(test))
    plt.title('train')
    plt.plot(Train)
    # plt.savefig("./model_png/5gcn_128_node_pos_floor_2000_acc.png")
    # plt.show()
    plt.clf()  # 重置画布
    plt.title('test')
    plt.plot(Test)
    # plt.savefig("./model_png/5gcn_128_node_pos_floor_2000_loss.png")
    # plt.show()

if __name__ == '__main__':
    # 1、绘制损失精度图
    path = r"./[6, 6, 6, 6]_Standard_taz.csv"
    data_plot(path)

2、计算混淆矩阵的精度、召回率和f1

def calculate_prediction(metrix):
    """
    计算精度
    """
    label_pre = []
    current_sum = 0
    for i in range(metrix.shape[0]):
        current_sum += metrix[i][i]
        label_total_sum = metrix.sum(axis=0)[i]
        pre = round(100 * metrix[i][i] / label_total_sum, 4)
        label_pre.append(pre)
    print("每类精度:", label_pre)
    all_pre = round(100 * current_sum / metrix.sum(), 4)
    print("总精度:", all_pre)
    return label_pre, all_pre


def calculate_recall(metrix):
    """
    先计算某一个类标的召回率;
    再计算出总体召回率
    """
    label_recall = []
    for i in range(metrix.shape[0]):
        label_total_sum = metrix.sum(axis=1)[i]
        label_correct_sum = metrix[i][i]
        recall = 0
        if label_total_sum != 0:
            recall = round(100 * float(label_correct_sum) / float(label_total_sum), 4)

        label_recall.append(recall)
    print("每类召回率:", label_recall)
    all_recall = round(np.array(label_recall).sum() / metrix.shape[0], 4)
    print("总召回率:", all_recall)
    return label_recall, all_recall

def calculate_f1(prediction, all_pre, recall, all_recall):
    """
    计算f1分数
    """
    all_f1 = []
    for i in range(len(prediction)):
        pre, reca = prediction[i], recall[i]
        f1 = 0
        if (pre + reca) != 0:
            f1 = round(2 * pre * reca / (pre + reca), 4)

        all_f1.append(f1)
    print("每类f1:", all_f1)
    print("总的f1:", round(2 * all_pre * all_recall / (all_pre + all_recall), 4))
    return all_f1
if __name__ == '__main__':
    # ************************************
    # 2、计算召回率、精度、f1
    metrix = \
        np.array([[84, 30, 16, 4, 4],
                  [11, 88, 14, 5, 1],
                  [13, 31, 75, 0, 0],
                  [12, 15, 3, 71, 1],
                  [31, 7, 5, 12, 67]])
    # metrix = np.array()
    # print(metrix,metrix.shape[0])
    print(metrix.sum(axis=0)[0], metrix.sum(axis=1)[0])
    # print("召回率:", calculate_recall(metrix))
    # print("精度:", calculate_prediction(metrix))
    label_pre, all_pre = calculate_prediction(metrix)
    label_recall, all_recall = calculate_recall(metrix)
    # ************************************
    calculate_f1(label_pre, all_pre, label_recall, all_recall)

3、绘制混淆矩阵展示图形,已经混淆矩阵平均值

def get_Confusion_matrix(path):
    numCount = 200  # 统计后多少数据,例如后200个,直接写200
    with open(path, mode="r", encoding="utf-8") as f:
        data = f.readlines()
    Confusion = []
    epoch_con = []
    for item in data:
        if ("Train" in item) or ("Test" in item):
            continue
        if "[[" in item:
            epoch_con = []
            # data = (np.array(item.strip()[2:])).tolist()
            datas = list((item.strip()[2:-1]).split())
            epoch_con.append(datas)
            continue
        if "]]" in item:
            datas = list((item.strip()[1:-2]).split())
            epoch_con.append(datas)
            Confusion.append(epoch_con)
            continue
        if "[" in item:
            datas = list((item.strip()[1:-1]).split())
            epoch_con.append(datas)

    sum = np.zeros((5, 5), dtype=int)
    # numCount = 200
    for temp in Confusion[-numCount:]:
        print(temp)
        sum += np.array(temp, dtype=int)

        metrix = sum / numCount
    print(metrix)
    # 绘制混淆矩阵
    plot_Confusion_matrix(metrix=metrix)

    print(metrix.sum(axis=0)[0], metrix.sum(axis=1)[0])
    # print("召回率:", calculate_recall(metrix))
    # print("精度:", calculate_prediction(metrix))
    label_pre, all_pre = calculate_prediction(metrix)
    label_recall, all_recall = calculate_recall(metrix)
    # ************************************
    calculate_f1(label_pre, all_pre, label_recall, all_recall)

if __name__ == '__main__':
	path = r"C:\Users\Administrator\Desktop\nj单一特征对比实验\12-16号graphsage的结果" \
           r"\lstm\[7, 7, 7, 7]_Standard_taz_lstm.csv"
    get_Confusion_matrix(path)
    

  • 1
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值