分类回归——组合算法之提升2:提升树以及Python实现

提升树是以分类树或回归树为基本分类器的提升算法。每个基本分类器的预测结果并不是最终结果,仅仅是最终结果的一个累加量。

算法简介

上一篇博客讲到,提升算法需要解决两个基本问题:

  1. 更改训练数据集或更改训练数据集的权重,以便下一基本分类器预测。
  2. 各个基本分类器之间如何线性组合 。
    以CART回归树为例,其解决这两个问题的方法就是:
  3. 以上一轮的组合树预测结果的残差作为当前的训练数据集y。
  4. 对回归而言,平方误差逐渐减小,故对于基本分类器而言,没有权重差别。

算法流程

  • Input: 训练数据集(X, y), 阈值epsilon
  • Output: 组合树
  • Step1: 构建候选切分点,以及存储对应的切分后左、右区域的样本索引(需要多次使用,故预先存储)
  • Step2: 根据CART回归树算法,挑选出平方误差最小的切分点,若平方误差小于epsilon则终止构建树,反之构建一个只含有两个叶节点的CART树,并转到步骤3
  • Step3: 基于当前的组合树,计算预测每个样本的预测残差,并作为下一棵树的训练集的y, 转步骤2

代码

"""
提升树:基于二叉回归树的提升算法
程序暂考虑输入为一维的情况
"""
from collections import defaultdict
import numpy as np


class BoostingTree:
    def __init__(self, epsilon=1e-2):
        self.epsilon = epsilon
        self.cand_splits = []  # 候选切分点
        self.split_index = defaultdict(tuple)  # 由于要多次切分数据集,故预先存储,切分后数据点的索引
        self.split_list = []  # 最终各个基本回归树的切分点
        self.c1_list = []  # 切分点左区域取值
        self.c2_list = []  # 切分点右区域取值
        self.N = None
        self.n_split = None

    def init_param(self, X_data):
        # 初始化参数
        self.N = X_data.shape[0]
        for i in range(1, self.N):
            self.cand_splits.append((X_data[i][0] + X_data[i - 1][0]) / 2)
        self.n_split = len(self.cand_splits)
        for split in self.cand_splits:
            left_index = np.where(X_data[:, 0]<= split)[0]
            right_index = list(set(range(self.N))-set(left_index))
            self.split_index[split] = (left_index, right_index)
        return

    def _cal_err(self, split, y_res):
        # 计算每个切分点的误差
        inds = self.split_index[split]
        left = y_res[inds[0]]
        right = y_res[inds[1]]

        c1 = np.sum(left) / len(left)
        c2 = np.sum(right) / len(right)
        y_res_left = left - c1
        y_res_right = right - c2
        res = np.hstack([y_res_left, y_res_right])
        res_square = np.apply_along_axis(lambda x: x ** 2, 0, res).sum()
        return res_square, c1, c2

    def best_split(self,y_res):
        # 获取最佳切分点,并返回对应的残差
        best_split = self.cand_splits[0]
        min_res_square, best_c1, best_c2 = self._cal_err(best_split, y_res)

        for i in range(1, self.n_split):
            res_square, c1, c2 = self._cal_err(self.cand_splits[i], y_res)
            if res_square < min_res_square:
                best_split = self.cand_splits[i]
                min_res_square = res_square
                best_c1 = c1
                best_c2 = c2

        self.split_list.append(best_split)
        self.c1_list.append(best_c1)
        self.c2_list.append(best_c2)
        return

    def _fx(self, X):
        # 基于当前组合树,预测X的输出值
        s = 0
        for split, c1, c2 in zip(self.split_list, self.c1_list, self.c2_list):
            if X < split:
                s += c1
            else:
                s += c2
        return s

    def update_y(self, X_data, y_data):
        # 每添加一颗回归树,就要更新y,即基于当前组合回归树的预测残差
        y_res = []
        for X, y in zip(X_data, y_data):
            y_res.append(y - self._fx(X[0]))
        y_res = np.array(y_res)
        res_square = np.apply_along_axis(lambda x: x ** 2, 0, y_res).sum()
        return y_res, res_square

    def fit(self, X_data, y_data):
        self.init_param(X_data)
        y_res = y_data
        while True:
            self.best_split(y_res)
            y_res, res_square = self.update_y(X_data, y_data)
            if res_square < self.epsilon:
                break
        return

    def predict(self, X):
        return self._fx(X)


if __name__ == '__main__':
    # data = np.array(
    #     [[1, 5.56], [2, 5.70], [3, 5.91], [4, 6.40], [5, 6.80], [6, 7.05], [7, 8.90], [8, 8.70], [9, 9.00], [10, 9.05]])
    # X_data = data[:, :-1]
    # y_data = data[:, -1]
    # BT = BoostingTree(epsilon=0.18)
    # BT.fit(X_data, y_data)
    # print(BT.split_list, BT.c1_list, BT.c2_list)
    X_data_raw = np.linspace(-5, 5, 100)
    X_data = np.transpose([X_data_raw])
    y_data = np.sin(X_data_raw)
    BT = BoostingTree(epsilon=0.1)
    BT.fit(X_data, y_data)
    y_pred = [BT.predict(X) for X in X_data]

    import matplotlib.pyplot as plt

    p1 = plt.scatter(X_data_raw, y_data, color='r')
    p2 = plt.scatter(X_data_raw, y_pred, color='b')
    plt.legend([p1, p2], ['real', 'pred'])
    plt.show()

回归拟合结果
预测结果
我的GitHub
注:如有不当之处,请指正。

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值