![d67b84db5da869f7ac81f556b1eff5b0.png](https://i-blog.csdnimg.cn/blog_migrate/3b856868dad67f298a80adac28afe8d6.jpeg)
提到回归树相信大家应该都不会觉得陌生(不陌生你点进来干嘛[捂脸]),大名鼎鼎的GBDT算法就是用回归树组合而成的。本文就回归树的基本原理进行讲解,并手把手、肩并肩地带您实现这一算法。
完整实现代码请参考本人的p...哦不是...github:
regression_tree.pygithub.com regression_tree_example.pygithub.com1. 原理篇
我们用人话而不是大段的数学公式来讲讲回归树是怎么一回事。
1.1 最简单的模型
如果预测某个连续变量的大小,最简单的模型之一就是用平均值。比如同事的平均年龄是28岁,那么新来了一批同事,在不知道这些同事的任何信息的情况下,直觉上用平均值28来预测是比较准确的,至少比0岁或者100岁要靠谱一些。我们不妨证明一下我们的直觉:
- 定义损失函数L,其中y_hat是对y预测值,使用MSE来评估损失:
- 对y_hat求导:
- 令导数等于0,最小化MSE,则:
- 所以,
结论,如果要用一个常量来预测y,用y的均值是一个最佳的选择。
1.2 加一点难度
仍然是预测同事年龄,这次我们预先知道了同事的职级,假设职级的范围是整数1-10,如何能让这个信息帮助我们更加准确的预测年龄呢?
一个思路是根据职级把同事分为两组,这两组分别应用我们之前提到的“平均值”模型。比如职级小于5的同事分到A组,大于或等于5的分到B组,A组的平均年龄是25岁,B组的平均年龄是35岁。如果新来了一个同事,职级是3,应该被分到A组,我们就预测他的年龄是25岁。
1.3 最佳分割点
还有一个问题待解决,如何取一个最佳的分割点对不同职级的同事进行分组呢? 我们尝试所有m个可能的分割点P_i,沿用之前的损失函数,计算Loss得到L_i。最小的L_i所对应的P_i就是我们要找的“最佳分割点”。
1.4 运用多个变量
再复杂一些,如果我们不仅仅知道了同事的职级,还知道了同事的工资(貌似不科学),该如何预测同事的年龄呢?
我们可以分别根据职级、工资计算出职级和工资的最佳分割点P_1, P_2,对应的Loss L_1, L_2。然后比较L_1和L2,取较小者。假设L_1 < L_2,那么按照P_1把不同职级的同事分为A、B两组。在A、B组内分别计算工资所对应的分割点,再分为C、D两组。这样我们就得到了AC, AD, BC, BD四组同事以及对应的平均年龄用于预测。
1.5 答案揭晓
如何实现这种1 to 2, 2 to 4, 4 to 8的算法呢?
熟悉数据结构的同学自然会想到二叉树,这种树被称为回归树,顾名思义利用树形结构求解回归问题。
2. 实现篇
本人用全宇宙最简单的编程语言——Python实现了回归树算法,便于学习和使用。简单说明一下实现过程,更详细的注释请参考本人github上的代码。
2.1 创建Node类
初始化,存储预测值、左右结点、特征和分割点
class
2.2 创建回归树类
初始化,存储根节点和树的高度。
class
2.3 计算分割点、MSE
根据自变量col、因变量label以及分割点split,计算分割后的MSE。
@staticmethod
2.4 计算最佳分割点
遍历特征某一列的所有的不重复的点,找出MSE最小的点作为最佳分割点。如果特征中没有不重复的元素则返回None。
def
2.5 选择最佳特征
遍历所有特征,计算最佳分割点对应的MSE,找出MSE最小的特征、对应的分割点,左右子节点对应的均值。如果所有的特征都没有不重复元素则返回None
def
2.6 规则转文字
将规则用文字表达出来,方便我们查看规则。
@staticmethod
2.7 获取规则
将回归树的所有规则都用文字表达出来,方便我们了解树的全貌。这里用到了队列+广度优先搜索。有兴趣也可以试试递归或者深度优先搜索。
def
2.8 训练模型
仍然使用队列+广度优先搜索,训练模型的过程中需要注意:
1. 控制树的最大深度max_depth;
2. 控制分裂时最少的样本量min_samples_split;
3. 叶子结点至少有两个不重复的y值;
4. 至少有一个特征是没有重复值的。
def
2.9 打印规则
模型训练完毕,查看一下模型生成的规则
def
2.10 预测一个样本
def
2.11 预测多个样本
def
3 效果评估
3.1 main函数
使用著名的波士顿房价数据集,按照7:3的比例拆分为训练集和测试集,训练模型,并统计准确度。
def
3.2 效果展示
最终生成了15条规则,拟合优度0.776,运行时间634毫秒,效果还算不错~
![70e2ff2ba3e14fc2f7ba59bbe3ed9f5e.png](https://i-blog.csdnimg.cn/blog_migrate/aac33066e41460c1d9b8582ee5a522f2.jpeg)
3.3 工具函数
本人自定义了一些工具函数,可以在github上查看
utils
- run_time - 测试函数运行时间
- load_boston_house_prices - 加载波士顿房价数据
- train_test_split - 拆分训练集、测试集
- get_r2 - 计算拟合优度
总结
回归树的原理:
损失最小化,平均值大法。 最佳行与列,效果顶呱呱。
回归树的实现:
一顿操作猛如虎,加减乘除二叉树。