使用scikit-learn模拟机器学习正负样本不均衡时ROC曲线和PR曲线,可调整识别正确和错误的样本数量,概率使用随机值

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
# @Date : 2023/10/16 16:57
# @Author : HELIN
from sklearn import metrics
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# 正样本
p_num_correct = 80  # 识别正确的数量
p_num_error = 20  # 识别错误的数量

# 负样本
n_num_correct = 10000  # 识别正确的数量
n_num_error = 200  # 识别错误的数量

# 模拟识别正确的数据
p_label = [1] * p_num_correct
p_score = (np.random.randint(51, 100, p_num_correct) / 100).tolist()

n_label = [0] * n_num_correct
n_score = (np.random.randint(1, 49, n_num_correct) / 100).tolist()

# 模拟识别错误的数据
p_label_error = [1] * p_num_error
p_score_error = (np.random.randint(1, 49, p_num_error) / 100).tolist()

n_label_error = [0] * n_num_error
n_score_error = (np.random.randint(51, 100, n_num_error) / 100).tolist()

y_test = p_label + n_label + p_label_error + n_label_error
y_score = p_score + n_score + p_score_error + n_score_error
y_pred = np.round(y_score).tolist()
print(y_test)
print(y_score)
print(y_pred)

# 报告
report = metrics.classification_report(y_test, y_pred)
print(report)

# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)

# 计算ROC曲线
fpr, tpr, _ = roc_curve(y_test, y_score)
roc_auc = auc(fpr, tpr)

# 创建热力图
plt.figure(figsize=(10, 8))
plt.subplot(2, 2, 1)
# 设置类别标签
class_names = ['0', '1']
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')

# 计算PR曲线
precision, recall, _ = precision_recall_curve(y_test, y_score)
average_precision = average_precision_score(y_test, y_score)

# 绘制ROC曲线
plt.subplot(2, 2, 3)
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")

# 绘制PR曲线
plt.subplot(2, 2, 4)
plt.plot(recall, precision, color='blue', lw=2, label='PR curve (area = %0.2f)' % average_precision)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc="lower left")

plt.tight_layout()
plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值