Note: Click here (https://urlify.cn/rUva6f) to download the full example code, or run this example in your browser via Binder.
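When performing classification, you often want to predict not only the class label but also the associated probability, which tells you how confident the model is in that prediction. However, not all classifiers provide well-calibrated probabilities: some are over-confident, while others are under-confident. The predicted probabilities are therefore often calibrated separately as a post-processing step. This example illustrates two different calibration methods and evaluates the quality of the returned probabilities with the Brier score (see https://en.wikipedia.org/wiki/Brier_score).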
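For binary labels y_i in {0, 1} and predicted positive-class probabilities p_i, the Brier score is the mean squared difference (p_i - y_i)^2. With sample weights w_i it becomes the weighted mean sum(w_i * (p_i - y_i)^2) / sum(w_i). Lower is better, and 0 means perfect. The sketch below checks a hand-written version against scikit-learn's `brier_score_loss`. The helper `manual_brier` and the toy values are hypothetical, chosen only for illustration.

```python
import numpy as np
from sklearn.metrics import brier_score_loss


def manual_brier(y_true, y_prob, sample_weight=None):
    """Hypothetical helper: (weighted) mean squared error between probability and outcome."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    if sample_weight is None:
        sample_weight = np.ones_like(y_true)
    return np.average((y_prob - y_true) ** 2, weights=sample_weight)


# Toy labels, predicted probabilities and weights (illustrative only)
y_true = np.array([0, 1, 1, 0])
y_prob = np.array([0.1, 0.9, 0.6, 0.5])
weights = np.array([1.0, 2.0, 1.0, 0.5])

# Unweighted: (0.01 + 0.01 + 0.16 + 0.25) / 4
print(manual_brier(y_true, y_prob), brier_score_loss(y_true, y_prob))
# Weighted: (0.01*1 + 0.01*2 + 0.16*1 + 0.25*0.5) / 4.5
print(manual_brier(y_true, y_prob, weights),
      brier_score_loss(y_true, y_prob, sample_weight=weights))
```

Both pairs should agree up to floating-point rounding. The unweighted value is 0.1075 and the weighted one is 0.07.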
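The example compares the probabilities predicted by a Gaussian naive Bayes classifier in three settings: without calibration, with sigmoid calibration, and with non-parametric isotonic calibration. Only the non-parametric model produces a calibration that returns probabilities close to the expected 0.5 for most samples in the middle cluster, whose samples carry both labels. This significantly improves the Brier score.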
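The difference between the two methods comes from the shape of the mapping each one fits from the uncalibrated score to a probability. Sigmoid (Platt-style) calibration fits a single S-shaped logistic curve. Isotonic calibration fits any non-decreasing step function, so it can stay flat at 0.5 over a range of scores. The sketch below is a rough stand-in for what `CalibratedClassifierCV` does internally, not its actual code. It fits a 1-D `LogisticRegression` and an `IsotonicRegression` to hypothetical scores whose middle third has 50/50 labels.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)

# Hypothetical uncalibrated scores in [0, 1] on a held-out calibration set
scores = rng.uniform(0.0, 1.0, size=3000)
# Low scores are negative, high scores positive, and the middle third is a
# 50/50 mix, loosely mimicking the ambiguous middle blob of the example
labels = np.where(scores < 1 / 3, 0, 1)
middle = (scores >= 1 / 3) & (scores <= 2 / 3)
labels[middle] = rng.randint(0, 2, size=middle.sum())

# Sigmoid calibration: p = 1 / (1 + exp(-(a * s + b))), a 1-D logistic fit
# (a large C means almost no regularization)
sigmoid = LogisticRegression(C=1e6)
sigmoid.fit(scores.reshape(-1, 1), labels)

# Isotonic calibration: best non-decreasing step function of the score
isotonic = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
isotonic.fit(scores, labels)

grid = np.array([0.1, 0.4, 0.5, 0.6, 0.9])
p_sigmoid = sigmoid.predict_proba(grid.reshape(-1, 1))[:, 1]
p_isotonic = isotonic.predict(grid)

for s, ps, pi in zip(grid, p_sigmoid, p_isotonic):
    print(f"score={s:.1f}  sigmoid={ps:.3f}  isotonic={pi:.3f}")
```

The isotonic column should be 0 at the low end, 1 at the high end, and close to 0.5 through the middle. The sigmoid column keeps rising smoothly across the middle because a single logistic curve cannot be both steep at the ends and flat in between.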
Output:
![73d8ed7e8e74273b289eeacb9f8ab435.png](https://img-blog.csdnimg.cn/img_convert/73d8ed7e8e74273b289eeacb9f8ab435.png)
![1e62dbbd85d50eb192f182d981c84d84.png](https://img-blog.csdnimg.cn/img_convert/1e62dbbd85d50eb192f182d981c84d84.png)
```
Brier scores: (the smaller the better)
No calibration: 0.104
With isotonic calibration: 0.084
With sigmoid calibration: 0.109
```
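The isotonic score of 0.084 is close to the best value achievable on this data. A predictor that outputs 0 for the left blob, 1 for the right blob and 0.5 for the mixed middle blob still pays 0.25 on every middle-blob sample. Those samples make up about one third of the data, so the score is roughly 0.25 / 3, or about 0.083. The sketch below checks this with a hypothetical "oracle" that reads the blob index, which the real classifier never sees. It is computed unweighted on the full dataset, whereas the script uses a weighted test split, so the two numbers are only approximately comparable.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import brier_score_loss

n_samples = 50000
centers = [(-5, -5), (0, 0), (5, 5)]
# With shuffle=False the samples come out blob by blob; `blob` is each sample's blob index
X, blob = make_blobs(n_samples=n_samples, centers=centers, shuffle=False,
                     random_state=42)
# Same relabelling as the script: first half negative, second half positive
y = (np.arange(n_samples) >= n_samples // 2).astype(int)

# Hypothetical oracle: 0 for the left blob, 0.5 for the middle blob, 1 for the right blob
oracle = np.array([0.0, 0.5, 1.0])[blob]
score = brier_score_loss(y, oracle)
print("Oracle Brier score: %1.3f" % score)
```

Because the middle blob holds 16667 of the 50000 samples, the score is exactly 0.25 * 16667 / 50000, which prints as 0.083.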
```python
print(__doc__)

# Author: Mathieu Blondel
#         Alexandre Gramfort
#         Balazs Kegl
#         Jan Hendrik Metzen
# License: BSD Style.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

from sklearn.datasets import make_blobs
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import brier_score_loss
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import train_test_split

n_samples = 50000
n_bins = 3  # 3 bins for calibration_curve, one per cluster (not used below)

# Generate 3 blobs with 2 classes, where the second blob contains
# half positive samples and half negative samples.
# The probability in this blob is therefore 0.5.
centers = [(-5, -5), (0, 0), (5, 5)]
X, y = make_blobs(n_samples=n_samples, centers=centers, shuffle=False,
                  random_state=42)

y[:n_samples // 2] = 0
y[n_samples // 2:] = 1
sample_weight = np.random.RandomState(42).rand(y.shape[0])

# Split into training and test sets for calibration
X_train, X_test, y_train, y_test, sw_train, sw_test = \
    train_test_split(X, y, sample_weight, test_size=0.9, random_state=42)

# Gaussian naive Bayes without calibration
clf = GaussianNB()
clf.fit(X_train, y_train)  # unweighted baseline; GaussianNB.fit also accepts sample_weight
prob_pos_clf = clf.predict_proba(X_test)[:, 1]

# Gaussian naive Bayes with isotonic calibration
clf_isotonic = CalibratedClassifierCV(clf, cv=2, method='isotonic')
clf_isotonic.fit(X_train, y_train, sample_weight=sw_train)
prob_pos_isotonic = clf_isotonic.predict_proba(X_test)[:, 1]

# Gaussian naive Bayes with sigmoid calibration
clf_sigmoid = CalibratedClassifierCV(clf, cv=2, method='sigmoid')
clf_sigmoid.fit(X_train, y_train, sample_weight=sw_train)
prob_pos_sigmoid = clf_sigmoid.predict_proba(X_test)[:, 1]

print("Brier scores: (the smaller the better)")

# sample_weight is keyword-only in brier_score_loss
clf_score = brier_score_loss(y_test, prob_pos_clf, sample_weight=sw_test)
print("No calibration: %1.3f" % clf_score)

clf_isotonic_score = brier_score_loss(y_test, prob_pos_isotonic,
                                      sample_weight=sw_test)
print("With isotonic calibration: %1.3f" % clf_isotonic_score)

clf_sigmoid_score = brier_score_loss(y_test, prob_pos_sigmoid,
                                     sample_weight=sw_test)
print("With sigmoid calibration: %1.3f" % clf_sigmoid_score)

# #############################################################################
# Plot the data and the predicted probabilities
plt.figure()
y_unique = np.unique(y)
colors = cm.rainbow(np.linspace(0.0, 1.0, y_unique.size))
for this_y, color in zip(y_unique, colors):
    this_X = X_train[y_train == this_y]
    this_sw = sw_train[y_train == this_y]
    # Marker size is proportional to the sample weight
    plt.scatter(this_X[:, 0], this_X[:, 1], s=this_sw * 50,
                c=color[np.newaxis, :],
                alpha=0.5, edgecolor='k',
                label="Class %s" % this_y)
plt.legend(loc="best")
plt.title("Data")

plt.figure()
# Sort test instances by their uncalibrated predicted probability
order = np.lexsort((prob_pos_clf, ))
plt.plot(prob_pos_clf[order], 'r', label='No calibration (%1.3f)' % clf_score)
plt.plot(prob_pos_isotonic[order], 'g', linewidth=3,
         label='Isotonic calibration (%1.3f)' % clf_isotonic_score)
plt.plot(prob_pos_sigmoid[order], 'b', linewidth=3,
         label='Sigmoid calibration (%1.3f)' % clf_sigmoid_score)
# Empirical fraction of positives in 25 equal-size chunks of the sorted test set
plt.plot(np.linspace(0, y_test.size, 51)[1::2],
         y_test[order].reshape(25, -1).mean(1),
         'k', linewidth=3, label=r'Empirical')
plt.ylim([-0.05, 1.05])
plt.xlabel("Instances sorted according to predicted probability "
           "(uncalibrated GNB)")
plt.ylabel("P(y=1)")
plt.legend(loc="upper left")
plt.title("Gaussian naive Bayes probabilities")

plt.show()
```
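The script defines `n_bins = 3` for `calibration_curve` but never calls it. The following sketch shows one way it could be used to compare the uncalibrated and isotonic-calibrated models bin by bin on the same data. It is an assumption-laden variant rather than part of the original example. `calibration_curve` takes no sample weights, so both the calibration fit and the bins here are unweighted, and the numbers will differ somewhat from the weighted run above.

```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

# Same data-generation recipe as the script above
n_samples = 50000
n_bins = 3
centers = [(-5, -5), (0, 0), (5, 5)]
X, y = make_blobs(n_samples=n_samples, centers=centers, shuffle=False,
                  random_state=42)
y[:n_samples // 2] = 0
y[n_samples // 2:] = 1
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.9, random_state=42)

clf = GaussianNB().fit(X_train, y_train)
clf_isotonic = CalibratedClassifierCV(GaussianNB(), cv=2,
                                      method="isotonic").fit(X_train, y_train)

curves = {}
for name, model in [("No calibration", clf), ("Isotonic calibration", clf_isotonic)]:
    prob_pos = model.predict_proba(X_test)[:, 1]
    # Per bin: observed fraction of positives vs. mean predicted probability
    # (empty bins are dropped, so fewer than n_bins points may come back)
    prob_true, prob_pred = calibration_curve(y_test, prob_pos, n_bins=n_bins)
    curves[name] = (prob_true, prob_pred)
    print(name)
    for p, t in zip(prob_pred, prob_true):
        print(f"  mean predicted {p:.3f} -> observed fraction {t:.3f}")
```

For a well-calibrated model, the mean predicted probability and the observed fraction of positives should be close in every bin.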
Total running time of the script: (0 minutes 0.617 seconds)
Estimated memory usage: 10 MB