前言
predict_proba 函数是用于分类任务的 scikit-learn 中的决策树模型(如 CART - 分类与回归树)的一个方法。特别地,对于 DecisionTreeClassifier 类,predict_proba 用于估计给定输入数据的每个类的概率。
示例
predict_proba 返回一个数组,其中每一行对应于输入数据中的一个样本,每一列对应于一个类别,值是该样本属于该类别的概率。
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 加载数据集
data = load_iris()
X = data.data
y = data.target
# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建和训练决策树分类器
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
# 预测测试集样本属于各类别的概率
proba = clf.predict_proba(X_test)
# 结果示例
print("概率输出:\n", proba)
参数:
- X_test 是用于测试的输入数据。
- predict_proba(X_test) 返回一个二维数组,其中的每一行对应于 X_test中的一个样本,每一列对应于一个类别。每个元素表示该样本属于某一特定类别的概率。
- 对于多分类问题(如 Iris数据集),对于每个样本,返回的各类别的概率之和为 1。
用途:
- 概率输出: predict_proba 提供了比 predict方法更丰富的信息,因为它输出的是样本属于各类别的概率,而不是简单的类别标签。这对于了解模型的不确定性和 confidence(置信度)很有帮助。
- 阈值调整: 对于二分类问题,可以通过查看正类的预测概率,并选择一个适当的概率阈值来调整模型的敏感性(灵敏度)和特异性。
- 排序任务: 在某些应用中,如推荐系统,样本的排序依据其属于某类的概率高低。这时 predict_proba 比 predict
提供了更有用的信息。