根据决策树分而治之的思想，使用 Gini 准则封装一个决策树分类算法，并支持调节两个超参数：最大树深（max_depth）和叶子节点最小样本数（min_samples_leaf）。
代码部分
import numpy as np
from collections import Counter
'''
Encapsulate the decision tree method
author:Evan
'''
class DecisionTreeClassifier:
    '''Binary decision-tree classifier built by recursive threshold splits.

    Each internal node tests ``x[dim] <= value`` (go left) versus
    ``x[dim] > value`` (go right). The split at every node is chosen by the
    module-level ``try_split`` helper, which scores candidates with Gini
    impurity.

    Hyperparameters
    ---------------
    max_depth : int, default 2
        Maximum number of split levels. The root split is at depth 1, so
        ``max_depth=0`` yields a single leaf.
    min_samples_leaf : int, default 1
        Minimum number of training samples each side of a split must keep;
        forwarded to ``try_split``.

    Attributes
    ----------
    tree_ : Node or None
        Root of the fitted tree; None until ``fit`` has been called.
    '''

    def __init__(self, max_depth=2, min_samples_leaf=1):
        self.tree_ = None
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf

    def fit(self, X, y):
        '''Grow the tree on training data and return ``self``.

        X is a 2-D array-like (rows are samples, columns are features) and y
        holds one class label per row; both are converted with ``np.asarray``.

        Raises ValueError if y is empty.
        '''
        X = np.asarray(X)
        y = np.asarray(y)
        if len(y) == 0:
            raise ValueError('Cannot fit a decision tree on an empty dataset')
        root = self.creat_tree(X, y)
        # creat_tree returns None when the root itself should not be split
        # (pure labels, max_depth < 1, or no admissible split). The fitted
        # model must still be able to predict, so fall back to one leaf.
        self.tree_ = root if root is not None else self._make_leaf(y)
        return self

    def creat_tree(self, X, y, current_depth=1):
        '''Recursively build the subtree for (X, y) at ``current_depth``.

        Returns an internal Node, or None when this node should become a leaf
        (the caller then creates the leaf with the majority label).
        '''
        if current_depth > self.max_depth:
            return None
        # A pure node cannot be improved by splitting; make it a leaf.
        if len(set(y)) <= 1:
            return None
        d, v, g = try_split(X, y, self.min_samples_leaf)
        # d == -1: no admissible split exists. A perfect split (g == 0) is
        # deliberately still applied -- it separates the classes exactly and
        # its pure children become leaves on the next recursion level.
        if d == -1:
            return None
        node = Node(d, v, g)
        X_left, X_right, y_left, y_right = cut(X, y, v, d)
        node.children_left = self._grow_child(X_left, y_left, current_depth + 1)
        node.children_right = self._grow_child(X_right, y_right, current_depth + 1)
        return node

    def _grow_child(self, X, y, depth):
        '''Build the subtree for one child, falling back to a majority leaf.'''
        subtree = self.creat_tree(X, y, depth)
        return subtree if subtree is not None else self._make_leaf(y)

    @staticmethod
    def _make_leaf(y):
        '''Return a leaf Node labelled with the most frequent class in y.'''
        return Node(l=Counter(y).most_common(1)[0][0])

    def predict(self, X):
        '''Return an array with the predicted label for each row of X.

        Raises AssertionError if called before ``fit``.
        '''
        if self.tree_ is None:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type is kept for compatibility.
            raise AssertionError('Call the fit() method first')
        return np.array([self._predict(x, self.tree_) for x in X])

    def _predict(self, x, node):
        '''Route one sample x from ``node`` down to a leaf; return its label.'''
        while node.label is None:
            if x[node.dim] <= node.value:
                node = node.children_left
            else:
                node = node.children_right
        return node.label
def cut(X, y, v, d):
    '''Split (X, y) in two at threshold v on feature column d.

    Rows with ``X[:, d] <= v`` form the left part and every other row forms
    the right part, matching the routing rule used at prediction time
    (``<=`` goes left, anything else goes right). Every row therefore lands
    in exactly one part; rows with NaN in column d go right.

    Returns (X_left, X_right, y_left, y_right).
    '''
    left = X[:, d] <= v
    # Use the complement rather than a second comparison (X > v), which
    # would silently drop rows where neither comparison is true (NaN).
    right = ~left
    return X[left], X[right], y[left], y[right]
def try_split(X, y, min_samples_leaf):
    '''Find the best single-feature threshold split of (X, y).

    Candidate thresholds are the midpoints between consecutive distinct
    values of each feature column. A candidate is scored by the
    size-weighted Gini impurity of the two parts it produces,

        g = (len(y_left) * gini(y_left) + len(y_right) * gini(y_right)) / n,

    and is admissible only if each part keeps at least ``min_samples_leaf``
    samples (and never zero). Ties keep the first candidate found, scanning
    features in column order and thresholds in increasing order.

    Returns (best_d, best_v, best_g). When no admissible split exists the
    result is (-1, -1, inf).
    '''
    n = len(y)
    # An empty side would make a useless split (and an unlabeled leaf).
    min_leaf = max(min_samples_leaf, 1)
    best_g, best_d, best_v = np.inf, -1, -1
    for d in range(X.shape[1]):
        values = np.unique(X[:, d])  # sorted, duplicates removed
        for v in (values[:-1] + values[1:]) / 2:
            _, _, y_left, y_right = cut(X, y, v, d)
            if len(y_left) < min_leaf or len(y_right) < min_leaf:
                continue
            # Weight each side by its size: an unweighted sum favours
            # peeling off tiny pure groups and can exceed any fixed bound.
            g = (len(y_left) * gini(y_left) + len(y_right) * gini(y_right)) / n
            if g < best_g:
                best_g, best_d, best_v = g, d, v
    return best_d, best_v, best_g
# define node class
class Node:
    '''One node of the decision tree: either an internal split or a leaf.

    An internal node stores the feature index ``dim``, the threshold
    ``value`` and the split score ``gini``; samples with
    ``x[dim] <= value`` are routed to ``children_left``, the rest to
    ``children_right``. A leaf is marked by a non-None ``label``, the class
    it predicts.
    '''

    def __init__(self, d=None, v=None, g=None, l=None):
        # Split description (left as None for leaves).
        self.dim, self.value, self.gini = d, v, g
        # Predicted class; a non-None value marks the node as a leaf.
        self.label = l
        # Subtrees, attached later by the tree builder.
        self.children_left = self.children_right = None

    def __repr__(self):
        fields = (self.dim, self.value, self.gini, self.label)
        return "Node(d={},v={},g={},l={})".format(*fields)
# compute gini
def gini(y):
    '''Return the Gini impurity ``1 - sum_k p_k**2`` of the labels in y.

    p_k is the fraction of labels equal to class k. An empty label set has
    impurity 0.0: there are no samples to misclassify.
    '''
    n = len(y)
    if n == 0:
        return 0.0
    return 1 - sum((count / n) ** 2 for count in Counter(y).values())