In [1]: np.random.RandomState(0).randn(1,1)
Out[1]: array([[1.76405235]])
numpy.random.RandomState
class numpy.random.RandomState(seed=None)
Container for the Mersenne Twister pseudo-random number generator.
seed : {None, int, array_like}, optional
Random seed used to initialize the pseudo-random number generator. Can be any integer between 0 and 2**32 - 1 inclusive, an array (or other sequence) of such integers, or None (the default). If seed is None, then RandomState will try to read data from /dev/urandom (or the Windows analogue) if available or seed from the clock otherwise.
关于seed以及相关分布函数介绍请参考:numpy.random.RandomState()函数用法详解
真正例率(True Positive Rate, TPR),假正例率(False Positive Rate, FPR)的计算请参考:sklearn中绘制 ROC 曲线的函数 roc_curve() 解释
以下为Multi-class中One vs all(即One-vs-Rest)的例子:(使用逻辑回归(Logistic Regression)分类并展示ROC曲线)
import matplotlib.pyplot as plt
import numpy as np  # was missing: np is used throughout the script below
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer

# One-vs-Rest ROC curve for a single class of the multi-class Iris problem.
# NOTE: the ``plot_chance_level`` argument of
# RocCurveDisplay.from_predictions requires scikit-learn >= 1.3.

# Load the dataset and replace integer targets with their class-name strings.
iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
y = target_names[y]

# Append standard-normal noise columns (200 per original feature) to make
# the classification task harder.
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
n_classes = len(np.unique(y))
X = np.concatenate(
    [X, random_state.randn(n_samples, 200 * n_features)], axis=1
)

# Stratified 50/50 split. Reusing the same RandomState continues its stream,
# so the whole script stays reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=random_state
)

classifier = LogisticRegression()
# Column j of y_score is the predicted probability of classifier.classes_[j].
y_score = classifier.fit(X_train, y_train).predict_proba(X_test)

# One-hot encode the test labels. LabelBinarizer and LogisticRegression both
# order classes as the sorted unique labels of y_train, so a given column
# index refers to the same class in y_onehot_test and in y_score.
label_binarizer = LabelBinarizer().fit(y_train)
y_onehot_test = label_binarizer.transform(y_test)
print(y_onehot_test.shape)  # (n_samples, n_classes)

class_of_interest = "virginica"
# Column index of the class of interest in both y_onehot_test and y_score.
class_id = np.flatnonzero(label_binarizer.classes_ == class_of_interest)[0]

RocCurveDisplay.from_predictions(
    y_onehot_test[:, class_id],  # binary truth: 1 = class_of_interest, 0 = rest
    y_score[:, class_id],  # predicted probability of class_of_interest
    name=f"{class_of_interest} vs the rest",
    color="darkorange",
    plot_chance_level=True,
)

# Build the title from the data so it stays correct if class_of_interest
# changes; for "virginica" this yields "Virginica vs (Setosa & Versicolor)".
other_classes = " & ".join(
    name.capitalize() for name in target_names if name != class_of_interest
)
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title(
    f"One-vs-Rest ROC curves:\n{class_of_interest.capitalize()} vs ({other_classes})"
)
plt.legend()
plt.show()