独家 | 指南:不平衡分类的成本敏感决策树(附代码&链接)


作者:Jason Brownlee

翻译:陈超

校对:冯羽

本文约3500字,建议阅读10+分钟

本文介绍了不平衡分类中的成本敏感决策树算法。

决策树算法对平衡分类是有效的,但在不平衡数据集上却表现不佳。

 

决策树分裂点是为了能够在最小混淆的情况下将所有实例分成两组。当两个组别分别都由其中一个类别的实例占主导,那么用于选择分裂点设置的标准即为合理,而事实上,少数类中的实例将会被忽略。

 

通过修改评估分裂点的标准并将每一类别的重要性均纳入考虑,即可解决这一问题,通常指的是加权的分裂点或者加权的决策树。

在本指南中,你将看到的是不平衡分类的加权决策树。

在学习完本指南之后,你将会了解:

  • 标准决策树算法是怎样不支持不平衡分类的。

  • 当选择分裂点时,决策树算法如何通过类权值对模型误差进行加权。

  • 如何配置决策树算法中类的权值以及如何对不同的类权值配置进行网格化搜索。

SMOTE算法,单类别分类,成本敏感学习,阈值移动,以及更多其他内容,请检索我的新书,内含30个逐步教程以及完整的Python源代码。

新书链接:

https://machinelearningmastery.com/cost-sensitive-decision-trees-for-imbalanced-classification/

好的,我们开始。

如何对不平衡分类执行加权决策树

Photo by Bonnie Moreland, some rights reserved.

指南概观

本指南分为四部分,他们分别是:

一、不平衡分类数据集

二、不平衡分类决策树

三、Scikit-Learn中使用加权决策树

四、加权决策树的网格化搜索

一、不平衡分类数据集

在开始深入到不平衡分类的决策修正之前,我们先定义一个不平衡数据集。

我们可以使用 make_classification()函数来定义一个合成的不平衡两类别分类数据集。我们将生成10000个实例,其中少数类和多数类的比例为1:100。

make_classification()函数:

https://machinelearningmastery.com/cost-sensitive-decision-trees-for-imbalanced-classification/

...
# define dataset
X, y = make_classification(n_samples=10000, n_features=2, n_redundant=0,
  n_clusters_per_class=1, weights=[0.99], flip_y=0, random_state=3)

一旦生成之后,我们可以总结类的分布来验证生成数据集是我们所期望的。

...
# summarize class distribution
counter = Counter(y)
print(counter)

最后,我们可以创造一个实例的散点图并依据类标签进行着色,来帮助我们理解该数据集中实例分类所面临的挑战。

...
# scatter plot of examples by class label
for label, _ in counter.items():
  row_ix = where(y == label)[0]
  pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

将这些代码整合在一起,生成合成数据集和绘制实例的完整示例。

# Generate and plot a synthetic imbalanced classification dataset
from collections import Counter
from sklearn.datasets import make_classification
from matplotlib import pyplot
from numpy import where
# define dataset
X, y = make_classification(n_samples=10000, n_features=2, n_redundant=0,
  n_clusters_per_class=1, weights=[0.99], flip_y=0, random_state=3)
# summarize class distribution
counter = Counter(y)
print(counter)
# scatter plot of examples by class label
for label, _ in counter.items():
  row_ix = where(y == label)[0]
  pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

运行这个示例将会先创造一个数据集,然后汇总类的分布。

 

我们可以看到这个数据集接近1:100的类的分布,有比10000个稍微少一些的实例在多数类当中,100个实例在少数类中。

Counter({0: 9900, 1: 100})

接下来,是数据集的散点图

  • 3
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值