机器学习入门之决策树

介绍

决策树是一种常见的机器学习算法,它的思想十分朴素,类似于我们平时利用选择做决策的过程。它是类似流程图的结构,其中每个内部节点表示一个测试功能,即类似做出决策的过程(动作),每个叶节点都表示一个类标签,即在计算所有特征之后做出的决定(结果)。标签和分支表示导致这些类标签的功能的连接。从根到叶的路径表示分类规则。比如下面这个“相亲决策树”:

决策树的构建

决策树通常有三个步骤:
特征选择
决策树的生成
决策树的修剪
决策树学习的算法通常是一个递归地选择最优特征,并根据该特征对训练数据进行分割,使得对各个子数据集有一个最好的分类的过程。这一过程对应着对特征空间的划分,也对应着决策树的构建。
这一过程对应着对特征空间的划分,也对应着决策树的构建。
开始:构建根节点,将所有训练数据都放在根节点,选择一个最优特征,按照这一特征将训练数据集分割成子集,使得各个子集有一个在当前条件下最好的分类。
如果这些子集已经能够被基本正确分类,那么构建叶节点,并将这些子集分到所对应的叶子节点去。
如果还有子集不能够被正确的分类,那么就对这些子集选择新的最优特征,继续对其进行分割,构建相应的节点,如此递归进行,直至所有训练数据子集被基本正确的分类,或者没有合适的特征为止。
每个子集都被分到叶节点上,即都有了明确的类,这样就生成了一颗决策树。
以上方法就是决策树学习中的特征选择和决策树生成,这样生成的决策树可能对训练数据有很好的分类能力,但对未知的测试数据却未必有很好的分类能力,即可能发生过拟合现象。我们需要对已生成的树自下而上进行剪枝,将树变得更简单,从而使其具有更好的泛化能力。具体地,就是去掉过于细分的叶结点,使其回退到父结点,甚至更高的结点,然后将父结点或更高的结点改为新的叶结点,从而使得模型有较好的泛化能力。。
决策树生成和决策树剪枝是个相对的过程,决策树生成旨在得到对于当前子数据集最好的分类效果(局部最优),而决策树剪枝则是考虑全局最优,增强泛化能力。
在对此有一定了解之后,我们先看看,如何在sklearn中将决策树用起来。然后再学习其中的细节。

示例代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data[:,2:] # iris有四个特征,这里取后两个,形成一个坐标点
y = iris.target
# 绘图
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
plt.scatter(X[y==2,0],X[y==2,1])
plt.show()

from sklearn.tree import DecisionTreeClassifier
# 创建决策树对象,最大深度max_depth为2层,criterion评判标准为entropy(熵)
dt_clt = DecisionTreeClassifier(max_depth=2,criterion='entropy')
# 将训练数据送给模型
dt_clt.fit(X,y)

# 绘制决策边界
def plot_decision_boundary(model, axis): # model是模型,axis是范围
    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
        np.linspace(axis[2], axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1),
    )
    X_new = np.c_[x0.ravel(), x1.ravel()]

    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)

    from matplotlib.colors import ListedColormap
    custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
    
    plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)

# 数据可视化    
plot_decision_boundary(dt_clt, axis=[0.5,7.5,0,3])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
plt.scatter(X[y==2,0],X[y==2,1])
plt.show()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值