初识机器学习 | 7.决策树

本文深入探讨了决策树的学习过程,包括信息熵和基尼系数的概念及其在节点划分中的作用。通过实例展示了如何利用信息熵和基尼系数优化决策树,并探讨了CART算法与其他超参数如max_depth、min_samples_split、min_samples_leaf和max_leaf_nodes的影响。最后,通过多个可视化例子展示了决策树的结构。
摘要由CSDN通过智能技术生成

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

%matplotlib
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
Using matplotlib backend: MacOSX
def plot_decision_boundary(model, axis):
    
    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1, 1),
        np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1, 1),
    )
    X_new = np.c_[x0.ravel(), x1.ravel()]

    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)

    from matplotlib.colors import ListedColormap
    custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
    
    plt.contourf(x0, x1, zz, cmap=custom_cmap)

初探决策树

莺尾花数据集,当前选取后面两个维度。莺尾花分为以下三类。

from sklearn import datasets

iris = datasets.load_iris()
x = iris.data[:, 2:]
y = iris.target

plt.scatter(x[y==0,0], x[y==0,1])
plt.scatter(x[y==1,0], x[y==1,1])
plt.scatter(x[y==2,0], x[y==2,1])
plt.show()

在这里插入图片描述

from sklearn.tree import DecisionTreeClassifier

# 设置树的深度为2, 使用信息熵
dt_clf = DecisionTreeClassifier(max_depth=2, criterion="entropy", random_state=42)
tree = dt_clf.fit(x, y)

plot_decision_boundary(dt_clf, axis=[0.5, 7.5, 0, 3])
plt.scatter(x[y==0,0], x[y==0,1], label='calss_0')
plt.scatter(x[y==1,0], x[y==1,1], label='calss_1')
plt.scatter(x[y==2,0], x[y==2,1], label='calss_2')
plt.legend()
plt.show()

在这里插入图片描述

import graphviz 
from sklearn.tree import export_graphviz

dot_data = export_graphviz(dt_clf, filled=True, rounded=True, special_characters=True) 
graph = graphviz.Source(dot_data) 
graph 

在这里插入图片描述

上述决策过程, 关键的点:

  • 每个节点选择哪个维度做划分
  • 维度上划分的阀值多少

可通过做优化信息熵或基尼系数,搜索到该节点上的维度和对应的阀值。

信息熵

计算公式:
H = − ∑ i = 1 k p i log ⁡ ( p i ) H=-\sum_{i=1}^{k} p_{i} \log \left(p_{i}\right) H=i=1kpilog(pi)

二分类信息熵

H = − x log ⁡ ( x ) − ( 1 − x ) log ⁡ ( 1 − x ) H=-x \log (x)-(1-x) \log (1-x) H=xlog(x)(1x)log(1x)

eg: 计算以下场景的信息熵

场景一:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值