回归树算法原理及实现

     作者:归辰

     来源:海边的拾遗者

由于现实中的很多问题是非线性的,当处理这类复杂的数据的回归问题时,特征之间的关系并不是简单的线性关系,此时,不可能利用全局的线性回归模型拟合这类数据。在上一篇文章"分类树算法原理及实现"中,分类树算法可以解决现实中非线性的分类问题,那么本文要讲的就是可以解决现实中非线性回归问题的回归树算法。

    本文以决策树中的CART树为例介绍回归树的原理及实现。

叶节点分裂指标

     通常在CART回归树中,样本的标签是一系列的连续值的集合,不能再使用基尼指数作为划分树的指标。在回归问题中我们可以发现,对于连续数据, 当数据分布比较分散时,各个数据与平均数的差的平方和较大,方差就较大;当数据分布比较集中时,各个数据与平均数的差的平方和较小。方差越大,数据的波动越大;方差越小,数据的波动就越小。因此,对于连续的数据,可以使用样本与平均值的差的平方和作为划分回归树的指标。

    方差是度量数据分布离散程度最常用的一种指标,对于包含m个训练样本的数据集D{(X(1),y(1)),(X(2),y(2)),…,(X(m),y(m))},则指标为数据集D中所有样本标签与均值的差的平方和:

    在回归树中即用该指标来进行叶节点分裂。现在让我们用代码将其实现。

import numpy as np
def err_cnt(dataSet):
    '''input: dataSet训练数据
    output: m*s^2总方差'''
    data = np.mat(dataSet)
    return np.var(data[:, -1]) * np.shape(data)[0]

回归树

     先定义样本被划分到左右子树的过程函数,原理为根据特征fea位置处的特征,按照值value将样本划分到左右子树中,当样本在特征fea处的值大于或者等于value时,将其划分到右子树中;否则,将其划分到左子树中。用代码实现如下:

def split_tree(data, fea, value):
    '''input: data训练样本
            fea需要划分的特征编号
            value指定的划分的值
    output: (set_1, set_2)左右子树的聚合'''
    set_1 = []  # 右子树的集合
    set_2 = []  # 左子树的集合
    for x in data:
        if x[fea] >= value:
            set_1.append(x)
        else:
            set_2.append(x)
    return (set_1, set_2) 

    另外需要定义计算当前叶子节点的值,计算的方法是使用划分到该叶子节点的所有样本的标签均值,代码如下: 

def leaf(dataSet):
    '''input: dataSet训练样本
    output: 均值'''
    data = np.mat(dataSet)
    return np.mean(data[:, -1])

    在按照特征对上述的数据进行划分的过程中,需要设置划分的终止条件和分类树比较类似。其构建过程可以分为以下几个步骤:

  • 对于当前训练数据集,遍历所有特征及其对应的所有可能切分点,寻找最佳切分特征及其最佳切分点,使得切分之后的各子集方差和最小,利用该最佳切分特征及其最佳切分点将训练数据集切分成两个子集,分别对应判别结果为左子树和判别结果为右子树。

  • 重复以下的步骤直至满足停止条件:为每一个叶子节点寻找最佳切分特征及其最佳切分点,将其划分为左右子树。 

  • 生成回归树。 

    现在先为树中的节点定义一个结构类,代码如下:

class node:
    def __init__(self, fea=-1, value=None, results=None, right=None, left=None):
        self.fea = fea  # 用于切分数据集的特征的列索引值
        self.value = value  # 设置划分的值
        self.results = results  # 存储叶节点的值
        self.right = right  # 右子树
        self.left = left  # 左子树

     然后我们可以利用递归的方法开始构建树了,在构建树的过程中,如果节点中的样本个数小于或者等于指定的最小样本数min_sample,则该节点不再划分。当节点需要划分时,首先计算当前节点的error值,划分后产生左子树和右子树,此时,计算左右子树的error值,若此时的error值小于最优的error值,则更新最优划分,当该节点划分完成后,继续对其左右子树进行划分。

def build_tree(data, min_sample, min_err):
    '''input: data训练样本
            min_sample叶子节点中最少样本数
            min_err最小的error
    output: node:树的根结点'''
    # 构建回归树,函数返回该树的根节点
    if len(data) <= min_sample:
        return node(results=leaf(data))
    
    # 1、初始化
    best_err = err_cnt(data)
    bestCriteria = None  # 存储最佳切分特征以及最佳切分点
    bestSets = None  # 存储切分后的两个数据集
    
    # 2、开始构建回归树
    feature_num = len(data[0]) - 1
    for fea in range(0, feature_num):
        feature_values = {}
        for sample in data:
            feature_values[sample[fea]] = 1
        
        for value in feature_values.keys():
            # 2.1、尝试划分
            (set_1, set_2) = split_tree(data, fea, value)
            if len(set_1) < 2 or len(set_2) < 2:
                continue
            # 2.2、计算划分后的error值
            now_err = err_cnt(set_1) + err_cnt(set_2)
            # 2.3、更新最优划分
            if now_err < best_err and len(set_1) > 0 and len(set_2) > 0:
                best_err = now_err
                bestCriteria = (fea, value)
                bestSets = (set_1, set_2)


    # 3、判断划分是否结束
    if best_err > min_err:
        right = build_tree(bestSets[0], min_sample, min_err)
        left = build_tree(bestSets[1], min_sample, min_err)
        return node(fea=bestCriteria[0], value=bestCriteria[1], \
                    right=right, left=left)
    else:
        return node(results=leaf(data))     

剪枝

     树回归中,当树中的节点对样本一直划分下去时,会出现的最极端的情况是:每一个叶子节点中仅包含一个样本,此时,叶子节点的值即为该样本的标签的值。这种情况易对训练样本"过拟合",通过这样方式训练出来的样本可以对训练样本拟合得很好,但是对新样本的预测效果将会较差,而这种问题一般大多发生在回归问题中。为了防止构建好的树模型过拟合,通常需要对回归树进行剪枝,剪枝的目的是防止回归树生成过多的叶子节点。在剪枝中主要分为:前剪枝和后剪枝。

    前剪枝是指在生成回归树的过程中对树的深度进行控制,防止生成过多的叶子节点。 

    后剪枝是指将训练样本分成两个部分,一部分用来训练树模型,这部分数据被称为训练数据,另一部分用来对生成的树模型进行剪枝,这部分数据被称为验证数据。如果出现过拟合的现象,则合并一些叶子节点来达到对树模型的剪枝。

     到这里整个流程基本就结束了~

◆ ◆ ◆  ◆ ◆

长按二维码关注我们


数据森麟公众号的交流群已经建立,许多小伙伴已经加入其中,感谢大家的支持。大家可以在群里交流关于数据分析&数据挖掘的相关内容,还没有加入的小伙伴可以扫描下方管理员二维码,进群前一定要关注公众号奥,关注后让管理员帮忙拉进群,期待大家的加入。

管理员二维码:

猜你喜欢

 笑死人不偿命的知乎沙雕问题排行榜

 用Python扒出B站那些“惊为天人”的阿婆主!

 全球股市跳水大战,谁最坑爹!

 华农兄弟、徐大Sao&李子柒?谁才是B站美食区的最强王者?

 你相信逛B站也能学编程

  • 2
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值