CART回归树原理及python实现

11 篇文章 1 订阅
3 篇文章 0 订阅


ID3、C4.5不能直接处理连续性的特征,需要将连续性的转化成离散的,但是会破坏连续性特征的内在结构。

一、什么是CART回归算法

CART分类回归树(Classification and Regression Tree)是一种典型的二叉决策树,可以做分类或者回归。如果待预测结果是离散型数据,则CART生成分类决策树;如果待预测结果是连续型数据,则CART生成回归决策树,类似于二分类树,内部结点特征的取值只有“是”和“否”,左分支是取值为“是”的分支,右分支是取值为“否”的分支。

数据对象的属性特征为离散型或连续型,并不是区别分类树与回归树的标准。

二、分类树与回归树区别

分类树与回归树的区别在样本的输出,如果样本输出是离散值,这是分类树;样本输出是连续值,这是回归树。分类树的输出是样本的类别,回归树的输出是一个实数。

分类模型:采用基尼系数的大小度量特征各个划分点的优劣。
回归模型:采用误差平方和度量

三、CART回归树与一般回归算法区别

CART回归树不同于线性回归模型,不是通过拟合所有的样本点来得到一个最终模型进行预测,它是一类基于局部的回归算法,通过采用一种二分递归分割的技术将数据集切分成多份,每份子数据集中的标签值分布得比较集中(比如以数据集的方差作为数据分布比较集中的指标),然后采用该数据集的平均值作为其预测值。这样,CART回归树算法也可以较好地拟合非线性数据。

假如数据集的标签(目标值)的集合呈现非线性目标函数的值,CART回归树算法将数据集切分成很多份,即将函数切成一小段一小段的,对于每一小段的值是较为接近的,可以每一小段的平均值作为该小段的目标值。

四、CART回归树创建

4.1 CART回归树的划分

在CART分类树中,是利用Gini指数作为划分的指标,通过样本中的特征对样本进行划分,直到所有的叶节点中的所有样本均为一个类别为止。其中,Gini指数表示的是数据的混乱程度,对于回归树,样本标签是连续数据,当数据分布比较分散时,各个数据与平均值的差的平方和较大,方差就较大;当数据分布比较集中时,各个数据与平均值的差的平方和较小。方差越大,数据的波动越大;方差越小,数据的波动就越小。因此,对于连续的数据,可以使用样本与平均值的差的平方和作为划分回归树的指标。

假设,有 m m m个训练样本, { ( X ( 1 ) , y ( 1 ) ) , ( X ( 2 ) , y ( 2 ) ) , … , ( X ( m ) , y ( m ) ) } \left\{\left(X^{(1)}, y^{(1)}\right),\left(X^{(2)}, y^{(2)}\right), \ldots,\left(X^{(m)}, y^{(m)}\right)\right\} {(X(1),y(1)),(X(2),y(2)),,(X(m),y(m))},则划分CART回归树的指标为:
m ∗ s 2 = ∑ i = 1 m ( y ( i ) − y ˉ ) 2 m * s^{2}=\sum_{i=1}^{m}\left(y^{(i)}-\bar{y}\right)^{2} ms2=i=1m(y(i)yˉ)2
下面是用Python实现CART回归树的划分指标:

import numpy as np
def calculate_err(data):
	"""
	input: data(list)
	output: m*s^2(float)
	"""
	data = np.mat(data)
	return np.var(data[:,-1]) * data.shape[0]

有了划分的标准,那么应该如何对样本进行划分呢?与CART分类树的划分一样,遍历各特征的所有取值,尝试将样本划分到树节点的左右子树中。只是因为不同的划分标准,在选择划分特征和特征值时的比较会有差异而已。 下面是左右子树划分的代码:

def split_tree(data, fea, value):
    '''根据特征fea中的值value将数据集data划分成左右子树
    input:  data(list):数据集
            fea(int):待分割特征的索引
            value(float):待分割的特征的具体值
    output: (set1,set2)(tuple):分割后的左右子树
    '''
    set_1 = []
    set_2 = []
    for x in data:
        if x[fea] >= value:
            set_1.append(x)
        else:
            set_2.append(x)
    return (set_1, set_2)

4.2 CART回归树的构建

CART回归树的构建也类似于CART分类树,主要的不同有三方面:

  1. 在选择划分特征与特征值的比较时,不是计算Gini指数,而是计算被划分后两个子数据集中各样本与平均值的差的平方和,选择此值较小的情况对数据集进行划分。
  2. 针对每一个叶节点,不是取样本的类别,而是各样本的标签值的平均平均值作为预测结果。
  3. 最后,CART回归树可通过设置参数进行前剪枝操作,此次构建中有设置了min_sample和min_err来控制树的节点是否需要进一步划分。
class node:
    '''树的节点的类
    '''
    def __init__(self, fea=-1, value=None, results=None, right=None, left=None):
        self.fea = fea  # 用于切分数据集的属性的列索引值
        self.value = value  # 设置划分的值
        self.results = results  # 存储叶节点所属的类别
        self.right = right  # 右子树
        self.left = left  # 左子树

def build_tree(data, min_sample, min_err):
    '''构建树
    input:  data(list):训练样本
    		min_sample(int): 叶子节点中最少的样本数
    		min_err(float): 最小的error
    output: node:树的根结点
    '''
    # 构建决策树,函数返回该决策树的根节点
    if len(data) <= min_sample:
        return node(results=leaf(data)
    
    # 1、初始化
    bestError = calculate_err(data)
    bestCriteria = None  # 存储最佳切分属性以及最佳切分点
    bestSets = None  # 存储切分后的两个数据集
    
     # 2、构建回归树 
    feature_num = len(data[0]) - 1  # 样本中特征的个数
    for fea in range(0, feature_num):
    	feature_values = {}
    	for sample in data:
    		feature_values[sample[fea]] = 1
    	
    	for value in feature_values.keys():
    		# 2.1 尝试划分
    		(set_1, set_2) = split_tree(data, fea, value)
    		if len(set_1) < 2 or len(set_2) < 2:
    			continue
    		# 2.2 计算划分后的error
    		nowError = calculate_err(set_1) + calculate_err(set_2)
			if nowError < bestError and len(set_1) > 0 and len(set_2) > 0:
				bestError = nowError
				bestCriteria = (fea, value)
				bestSets = (set_1, set_2)
    
    # 3、判断划分是否结束
    if bestError > min_err:
        right = build_tree(bestSets[0],min_sample, min_err)
        left = build_tree(bestSets[1],min_sample, min_err)
        return node(fea=bestCriteria[0], value=bestCriteria[1], right=right, left=left)
    else:
        return node(results=leaf(data))  


def leaf(data):
	"""
	计算叶节点的平均值
	"""
	data = np.mat(data)
	return np.mean(data[:,-1])

4.3 CART回归树的剪枝

在CART回归树中,当树中的节点对样本一直划分下去时,会出现最极端的情况:每一个叶子节点中仅包含一个样本,此时,叶子节点的值即为该样本的标签均值。这种情况极易对数据过拟合,为防止发生过拟合,需要对CART回归树进行剪枝,以防止生成过多的叶子节点。
在剪枝中主要分为:预剪枝和后剪枝。

  1. 预剪枝是指在生成树的过程中对树的深度进行控制,防止生成过多的叶子节点。在build_tree函数中就使用了min_sample和min_err来控制树中的节点是否需要进行更多的划分。通过不断调节这两个参数来找到合适的CART树模型。
  2. 后剪枝是指将训练样本分成两个部分,一部分用来训练CART树模型,这部分数据被称为训练数据,另一部分用来对生成的树模型进行剪枝,称为验证数据。
    在后剪枝的过程中,通过验证生成的CART树模型是否在验证数据集上发生过拟合,如果出现过拟合的现象,则合并一些叶子节点来达到CART树模型的剪枝。

本文中主要使用的是预剪枝,通过调整min_sample和min_err参数的方式。

4.4 数据预测

CART回归树模型构建好后,利用训练数据来训练该模型,最后训练好的回归树模型需要进行评估,了解预测值与实际值间的差距是否在接受范围内。

对CART回归树进行评估时,因需要对数据集中各样本进行预测,然后利用预测值与原始样本的标签值计算残差,所以,首先要建立predict函数。

def predict(sample,tree):
	"""对每一个样本sample进行预测
	input: sample(list)
	output: results(float)
	"""
	# 如果只是树根
	if tree.results != None:
		return tree.results
	else:
		# 有子树
		val_sample = sample(tree.fea)
		branch = None
		if val_sample >= tree.value:
			branch = tree.right
		else:
			branch = tree.left
		return predict(sample, branch)

接下来,对数据集进行预测,并计算残差,代码如下:

def evaluate_error(data, tree):
	"""评估CART回归树模型
	input: data(list)
			tree: 训练好的CART回归树模型
	output: total_error/m(float): 均方误差
	"""
	m = len(data)
	total_error = 0.0
	for i in range(m):
		sample = data[i,:-1]
		pred = predict(sample, tree)
		total_error += np.square(data[i,-1] - pred)
	return total_error / m

五、总结

CART回归树算法是用来预测目标为连续值的算法,是一类基于局部的回归算法。CART回归树的构建是先利用类似CART分类树的方法将数据集进行划分,但划分的标准不同,本文使用的指标是各数值与平均值的差的平方和,在划分时选择使划分后的左右子树的该指标之和较小的特征与特征值,直到最后叶子节点中的样本个数达到min_sample或者各数值与平均值的差的平方和达到min_err为止,最后基于该叶子节点中的平均值作为预测值。在训练模型中,为避免过拟合,使用了参数min_sample和min_err来控制CART回归树模型的生成。

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值