Boosting Tree--回归树 python简单实现

Boosting Tree–回归树 python简单实现

1.先验知识

参考博客: https://blog.csdn.net/qiu931110/article/details/80748540

2.代码

class weakLearner():

    # 计算均方误差
    def __mse(self, left_node):

        return np.sum((np.average(left_node) - np.array(left_node)) ** 2)

    # 特征值从小到大排序好,错位相加,生成候选划分点
    def Get_stump_list(self, X):
        tmp1 = list(X.copy())
        tmp2 = list(X.copy())
        tmp1.insert(0, 0)
        tmp2.append(0)
        stump_list = ((np.array(tmp1) + np.array(tmp2)) / float(2))[1:-1]

        return stump_list

    #根据x与候选划分点进行比较,二分数据集
    def __binSplitData(self, stump_list, X, Y):
        left_node = []
        right_node = []
        for j in range(np.shape(X)[0]):
            if X[j] < stump_list:
                left_node.append(Y[j])
            else:
                right_node.append(Y[j])

        return left_node, right_node

    # 根据最小均方误差选择最佳划分点
    def __bestSplit(self, stump_list, X, Y):
        best_mse = np.inf
        for i in range(np.shape(stump_list)[0]):
            left_node, right_node = self.__binSplitData(stump_list[i], X, Y)
            left_mse = self.__mse(left_node)
            right_mse = self.__mse(right_node)

            if best_mse > (left_mse + right_mse):
                best_mse = left_mse + right_mse
                best_f_val = stump_list[i]
        return best_f_val

    # 建立CART树
    def __CART(self, X, Y):

        stump_list = self.Get_stump_list(X) #生成候选分割节点
        best_f_val = self.__bestSplit(stump_list, X, Y) #获得最佳分割节点
        tree = dict()
        tree['cut_f_val'] = best_f_val  # 最佳划分点
        left_node, right_node = self.__binSplitData(best_f_val, X, Y)
        tree['left_tree'] = left_node  # 存储左分支的所有样本的标签值
        tree['right_tree'] = right_node
        left_mse = self.__mse(left_node)
        right_mse = self.__mse(right_node)
        tree['left_mse'] = left_mse
        tree['right_mse'] = right_mse

        return tree

    # 训练CART树
    def train(self, X, Y):

        self.tree = self.__CART(X, Y)

        return self.tree

    #CART预测
    def predict(self, X):
        return np.array([self.__predict_one(x, self.tree) for x in X])

    #预测一个样本
    def __predict_one(self, x, tree):
        cut_val = tree['cut_f_val']
        Y_left = np.average(tree['left_tree'])
        Y_right = np.average(tree['right_tree'])

        result = Y_left if x <= cut_val else Y_right

        return result



#提升树的类
class Boosting_tree():
    def __init__(self, tree_num: int = 6, classifier = weakLearner):
        self.tree_num = tree_num
        self.weakLearner = weakLearner
        self.residual = []
        self.Trees = []

    def fit(self, X, Y):
        Tree_num = self.tree_num #默认为 6
        X = np.array(X)  # 把列表转化为数组
        Y = np.array(Y)

        residual = Y.copy()  # 初始化残差
        # 产生每一棵树
        for num in range(Tree_num):
            # 每次新生成树后,还需要再次更新残差residual
            wl = self.weakLearner() #实例化弱学习器
            Tree = wl.train(X, residual) #用上一颗树的残差拟合下一个树
            #计算当前模型的残差
            Y_left = np.average(Tree['left_tree']) #计算左叶子节点
            Y_right = np.average(Tree['right_tree']) #计算左叶子节点
            left_residual = np.array(Tree['left_tree']) - Y_left #计算左边的残差
            right_residual = np.array(Tree['right_tree']) - Y_right #计算右边的残差
            residual = np.append(left_residual, right_residual) #合并残差

            self.residual.append(residual) #存储每棵树的残差
            self.Trees.append(wl) #存储训练好的模型
        return self.Trees

    #模型预测
    def predict(self, X):

        M = self.tree_num #弱学习器的个数
        y_ = 0 #初始化预测值为0
        for m in range(M):
            y_ += self.Trees[m].predict(X)

        return y_

    #计算模型预测误差
    def error(self, Y, y_predict):

        return np.sum((Y - y_predict) ** 2)


if __name__ == "__main__":

    import numpy as np

    Y = np.array([5.56, 5.7, 5.91, 6.4, 6.8, 7.05, 8.9, 8.7, 9, 9.05])
    # 已经排好序了。实际情况中单一特征的数据或者多特征的数据,选择切分点的时候也像决策树一样选择
    X = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    Trees_reg = Boosting_tree() #实例化提升树
    Trees = Trees_reg.fit(X, Y) #拟合模型


    y_predict = Trees_reg.predict(X) #预测
    print(y_predict)

    error = Trees_reg.error(Y, y_predict) #计算损失
    print("The error is ", error)

    print("residual of every tree==", Trees_reg.residual) #查看每棵树的残差






  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小凉爽&玉米粒

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值