From np.random.RandomState(0), common in roc_curve examples, to OvR

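This article uses numpy's RandomState class to set a random seed and shows how to apply the One-vs-All (One-vs-Rest, OvR) strategy to a multi-class problem. The example trains a LogisticRegression classifier (logistic regression, not linear regression) and evaluates it with a ROC curve, focusing on how well it recognizes the virginica class. The code covers loading the dataset, adding noise features, splitting into training and test sets, and plotting the ROC curve.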
In [1]: np.random.RandomState(0).randn(1,1)
Out[1]: array([[1.76405235]])
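Because each RandomState instance carries its own Mersenne Twister state, two instances built from the same seed produce identical streams, and calling .seed(0) on an existing instance rewinds it to the same starting point. That is why the value above is reproducible. A minimal sketch (the variable names are illustrative):

```python
import numpy as np

# Two independent generators created with the same seed
rng_a = np.random.RandomState(0)
rng_b = np.random.RandomState(0)

first_a = rng_a.randn(1, 1)  # same call as in the session above
first_b = rng_b.randn(1, 1)

# Identical seeds give identical sequences
print(np.array_equal(first_a, first_b))  # True

# Re-seeding an existing instance resets its internal state
rng_a.randn(5)   # advance the stream
rng_a.seed(0)    # back to the seed-0 state
again = rng_a.randn(1, 1)
print(np.array_equal(again, first_b))  # True
```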

numpy.random.RandomState

class numpy.random.RandomState(seed=None)

Container for the Mersenne Twister pseudo-random number generator.

seed : {None, int, array_like}, optional

Random seed used to initialize the pseudo-random number generator. Can be any integer between 0 and 2**32 - 1 inclusive, an array (or other sequence) of such integers, or None (the default). If seed is None, then RandomState will try to read data from /dev/urandom (or the Windows analogue) if available or seed from the clock otherwise.
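The accepted seed forms can be checked directly. The sketch below passes an integer, an array of integers, and None, then probes the documented range limits; the helper seed_is_valid is hypothetical and exists only for this illustration. Out-of-range seeds raise ValueError in NumPy's legacy seeding code.

```python
import numpy as np

# An integer seed in [0, 2**32 - 1]
rs_int = np.random.RandomState(42)

# An array (or other sequence) of such integers is also accepted
rs_arr = np.random.RandomState([1, 2, 3])

# seed=None reads OS entropy (e.g. /dev/urandom), so draws differ between runs
rs_none = np.random.RandomState(None)

def seed_is_valid(seed):
    """Hypothetical helper: True if RandomState accepts the seed."""
    try:
        np.random.RandomState(seed)
        return True
    except ValueError:
        return False

print(seed_is_valid(2**32 - 1))  # True
print(seed_is_valid(2**32))      # False
print(seed_is_valid(-1))         # False
```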

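For an introduction to seeds and the related distribution functions, see the article "numpy.random.RandomState()函数用法详解" (a detailed guide to numpy.random.RandomState()).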

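For how the True Positive Rate (TPR) and False Positive Rate (FPR) are computed, see the article "sklearn中绘制 ROC 曲线的函数 roc_curve() 解释" (an explanation of roc_curve(), the sklearn function for plotting ROC curves).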
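As a quick check of the definitions TPR = TP / (TP + FN) and FPR = FP / (FP + TN), the sketch below recomputes both rates by hand at every threshold returned by sklearn.metrics.roc_curve, treating a sample as positive when its score is >= the threshold. The toy labels and scores are made up for illustration, and rates_at is a hypothetical helper.

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy binary problem (illustrative values only)
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

def rates_at(threshold):
    """Hypothetical helper: (TPR, FPR) when predicting positive for score >= threshold."""
    pred = y_score >= threshold
    tp = np.sum(pred & (y_true == 1))
    fn = np.sum(~pred & (y_true == 1))
    fp = np.sum(pred & (y_true == 0))
    tn = np.sum(~pred & (y_true == 0))
    return tp / (tp + fn), fp / (fp + tn)

fpr, tpr, thresholds = roc_curve(y_true, y_score)

# Recompute each ROC point manually; columns are (TPR, FPR)
manual = np.array([rates_at(t) for t in thresholds])
print(np.allclose(manual[:, 0], tpr))  # True
print(np.allclose(manual[:, 1], fpr))  # True
```

The first threshold returned by roc_curve lies above every score, so it yields the (0, 0) starting point of the curve.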

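Below is a One-vs-All (also called One-vs-Rest) example for a multi-class problem, using logistic regression as the classifier and plotting the resulting ROC curve: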

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import RocCurveDisplay
import matplotlib.pyplot as plt

# Load the dataset
iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
y = iris.target_names[y]

# Add noise: append 200 * n_features columns of standard-normal noise to make the task harder
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
n_classes = len(np.unique(y))
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
(
    X_train,
    X_test,
    y_train,
    y_test,
) = train_test_split(X, y, test_size=0.5, stratify=y, random_state=random_state)

classifier = LogisticRegression()
y_score = classifier.fit(X_train, y_train).predict_proba(X_test)

label_binarizer = LabelBinarizer().fit(y_train)
y_onehot_test = label_binarizer.transform(y_test)
y_onehot_test.shape  # (n_samples, n_classes)
class_of_interest = "virginica"
class_id = np.flatnonzero(label_binarizer.classes_ == class_of_interest)[0]    # index

RocCurveDisplay.from_predictions(
    y_onehot_test[:, class_id],    # one-hot column for virginica (the last column, since classes_ is sorted)
    y_score[:, class_id],
    name=f"{class_of_interest} vs the rest",
    color="darkorange",
    plot_chance_level=True,
)
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("One-vs-Rest ROC curves:\nVirginica vs (Setosa & Versicolor)")
plt.legend()
plt.show()
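The plot above shows only the virginica column. To summarize all three One-vs-Rest curves with one number, a common approach is to compute the ROC AUC for each class against the rest and average them (macro averaging). The sketch below does this by hand and checks it against roc_auc_score(..., multi_class="ovr"); it deliberately uses the clean iris features without the noise columns and sets max_iter=1000, both assumptions of this sketch rather than part of the original example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=0
)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
y_score = clf.predict_proba(X_test)  # columns follow clf.classes_

# One-vs-Rest: binarize labels, then score each class against all the others
y_onehot = label_binarize(y_test, classes=clf.classes_)
per_class_auc = [
    roc_auc_score(y_onehot[:, k], y_score[:, k]) for k in range(len(clf.classes_))
]
macro_manual = np.mean(per_class_auc)

# sklearn's built-in OvR macro average should match the manual mean
macro_sklearn = roc_auc_score(y_test, y_score, multi_class="ovr", average="macro")
print(per_class_auc)
print(np.isclose(macro_manual, macro_sklearn))  # True
```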

 
