个人博客:https://iyinst.github.io/2021/04/22/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0-%E6%89%8B%E6%92%95CART/
CART是分类和回归树,其原理不再赘述,本文给出一个手写的CART的思路和代码。
import numpy as np
from utils import se_loss, gini_loss,time_count
class Node:
    """A single tree node: either an internal split or a leaf.

    Attributes:
        value: leaf prediction (mean target or majority class); None for
            internal nodes.
        left, right: child Nodes; None for leaves.
        instances_index: indices of the training rows routed to this node.
        split_feature: column index this node splits on (internal nodes only).
        split_point: threshold for the split (internal nodes only).
    """

    def __init__(self, value=None, left=None, right=None, instances_index=None):
        # Split metadata is filled in later, while the tree is being grown.
        self.split_feature = None
        self.split_point = None
        self.instances_index = instances_index
        self.value = value
        self.left = left
        self.right = right
class CART:
def __init__(self, objective='regression', max_depth=10, min_samples_leaf=2, min_impurity_decrease=0.):
    """CART (classification and regression tree).

    Args:
        objective: 'regression' (squared-error loss, leaf value = mean of
            targets) or 'classification' (Gini loss, leaf value = majority
            class label).
        max_depth: maximum depth of the tree; nodes at this depth become
            leaves.
        min_samples_leaf: minimum number of samples required in a leaf.
        min_impurity_decrease: minimum loss reduction required to accept
            a split.

    Raises:
        ValueError: if `objective` is neither 'regression' nor
            'classification'.
    """
    self.objective = objective
    if self.objective == 'regression':
        self.loss = se_loss
        self.leaf_weight = np.mean
    elif self.objective == 'classification':
        self.loss = gini_loss
        # Majority vote; assumes labels are non-negative integers
        # (required by np.bincount).
        self.leaf_weight = lambda y: np.argmax(np.bincount(y))
    else:
        # Fail fast: previously an unknown objective silently left
        # self.loss / self.leaf_weight undefined, so fit() would later
        # crash with a confusing AttributeError.
        raise ValueError(
            f"objective must be 'regression' or 'classification', got {objective!r}")
    self.min_impurity_decrease = min_impurity_decrease
    self.max_depth = max_depth
    self.root = Node()
    self.min_samples_leaf = min_samples_leaf
    # Current depth of the tree; updated as nodes are generated.
    self.depth = 1
# @time_count
def fit(self, X, y):
    """Grow the tree on feature matrix X (n_samples, n_features) and targets y.

    Every training row starts at the root; recursive splitting is
    delegated to _generate_node.
    """
    n_samples = X.shape[0]
    self.root.instances_index = list(range(n_samples))
    self._generate_node(self.root, X, y, self.depth)
def _generate_node(self, root: Node, X: np.array, y: np.array, depth: int):
# 大于最大深度剪枝
self.depth = max(depth, self.depth)
if depth >= self.max_depth:
root.value = self.leaf_weight(y[root.instances_index])
return
split_feature, split_point = -1, -1
min_loss = self.loss(y[root.instances_index])
# 寻找分裂点
for feature_index in range(X.shape[1]):
split_candidate = sorted(np.unique(X[root.instances_index, feature_index]))
for candidate in split_candidate:
left = [i for i in root.instances_index if X[i, feature_index] <= candidate]
right = [i for i in root