【机器学习笔记】手撕CART

个人博客:https://iyinst.github.io/2021/04/22/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0-%E6%89%8B%E6%92%95CART/

CART（Classification and Regression Tree，分类与回归树）既可用于分类也可用于回归，其原理不再赘述；本文给出手写 CART 的思路和代码。

import numpy as np
from utils import se_loss, gini_loss,time_count


class Node:
    """One node of the CART binary tree.

    Attributes:
        value: Prediction stored on the node; the tree builder fills it in
            for leaves with ``CART.leaf_weight`` of the node's targets.
        left, right: Child subtrees, or None.
        instances_index: List of training-row indices that reach this node.
        split_feature, split_point: Column index and threshold of the split.
            Always None at construction; presumably assigned by the tree
            builder when the node is split.
    """

    def __init__(self, value=None, left=None, right=None, instances_index=None):
        # Payload carried by the node (prediction for a leaf).
        self.value = value
        # Children are stored together; both None means "leaf".
        self.left, self.right = left, right
        # Samples routed to this node during training.
        self.instances_index = instances_index
        # No split has been chosen yet.
        self.split_feature = self.split_point = None


class CART:
    def __init__(self, objective='regression', max_depth=10, min_samples_leaf=2, min_impurity_decrease=0.):
        """Configure an (unfitted) CART model.

        Args:
            objective: ``'regression'`` or ``'classification'``. Selects the
                node loss (``se_loss`` or ``gini_loss`` from ``utils``) and how
                a leaf's prediction is computed from its targets: the mean
                for regression, the most frequent label for classification.
            max_depth: A node whose depth is ``>= max_depth`` is made a leaf
                by the tree builder; ``fit`` builds the root at depth 1.
            min_samples_leaf: Stored for the tree builder; presumably the
                minimum number of samples a child must receive. TODO confirm
                against ``_generate_node``.
            min_impurity_decrease: Stored for the tree builder; presumably the
                minimum loss reduction required to accept a split. TODO
                confirm against ``_generate_node``.

        Raises:
            ValueError: If ``objective`` is not one of the supported values.
                Failing here is clearer than a later AttributeError on
                ``self.loss`` / ``self.leaf_weight``.
        """
        self.objective = objective
        if self.objective == 'regression':
            self.loss = se_loss
            self.leaf_weight = np.mean
        elif self.objective == 'classification':
            self.loss = gini_loss

            def _majority_label(labels):
                # Most frequent label. np.unique returns sorted values and
                # argmax picks the first maximum, so ties resolve to the
                # smallest label. This matches np.argmax(np.bincount(labels))
                # for non-negative ints, and also works for negative, float
                # and string labels, which bincount rejects.
                values, counts = np.unique(labels, return_counts=True)
                return values[np.argmax(counts)]

            self.leaf_weight = _majority_label
        else:
            raise ValueError(
                f"objective must be 'regression' or 'classification', got {objective!r}"
            )

        self.min_impurity_decrease = min_impurity_decrease
        self.max_depth = max_depth
        self.root = Node()
        self.min_samples_leaf = min_samples_leaf
        # Deepest level reached so far; the tree builder raises it as it grows.
        self.depth = 1

    # @time_count
    def fit(self, X, y):
        """Grow the tree on training data.

        Args:
            X: 2-D feature matrix, indexed as ``X[row, feature]``; converted
                with ``np.asarray``.
            y: Targets, one per row of ``X``; converted with ``np.asarray`` so
                the builder can index it with a list of row indices.

        Returns:
            self, so calls can be chained (``CART().fit(X, y)``).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        # Reset state on every call. The builder raises self.depth as it grows
        # the tree, and fit passes self.depth as the starting depth. Without
        # this reset, a second fit would start at the depth the previous tree
        # reached and could stop immediately at max_depth. A fresh root also
        # drops any state left on the previous tree's root.
        self.depth = 1
        self.root = Node(instances_index=list(range(X.shape[0])))
        self._generate_node(self.root, X, y, self.depth)
        return self

    def _generate_node(self, root: Node, X: np.array, y: np.array, depth: int):

        # 大于最大深度剪枝
        self.depth = max(depth, self.depth)
        if depth >= self.max_depth:
            root.value = self.leaf_weight(y[root.instances_index])
            return

        split_feature, split_point = -1, -1
        min_loss = self.loss(y[root.instances_index])

        # 寻找分裂点
        for feature_index in range(X.shape[1]):
            split_candidate = sorted(np.unique(X[root.instances_index, feature_index]))
            for candidate in split_candidate:
                left = [i for i in root.instances_index if X[i, feature_index] <= candidate]
                right = [i for i in root
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值