在《机器学习学习笔记(18)----CART(Classification And Regression Tree)算法》,我们给出了CART分类树的特征划分的算法,接下来,我们用python实现一个CART分类树算法(cartctree.py)(参考自《Python机器学习算法:原理,实现与案例》):
import numpy as np
class CartClassificationTree:
    """CART binary classification tree for numeric feature values.

    Every internal node tests ``x[feature_index] > feature_value``; samples
    for which the test is true go to the ``left`` subtree, all others to the
    ``right`` subtree.  Class labels must be non-negative integers because
    the node value is computed with ``np.bincount``.
    """

    class Node:
        """A tree node; a leaf keeps ``left`` and ``right`` set to ``None``."""

        def __init__(self):
            # Majority class label of the samples that reached this node.
            self.value = None
            # Attributes that are only filled in for internal (non-leaf) nodes.
            self.feature_index = None
            self.feature_value = None
            self.left = None
            self.right = None

    def __init__(self, gini_threshold=0.01, gini_dec_threshold=0., min_samples_split=2):
        """Store the stopping criteria used while growing the tree.

        Args:
            gini_threshold: a node whose Gini impurity is below this value
                becomes a leaf.
            gini_dec_threshold: a split is rejected when it lowers the Gini
                impurity by less than this value.
            min_samples_split: a node holding fewer samples than this
                becomes a leaf.
        """
        self.gini_threshold = gini_threshold
        self.gini_dec_threshold = gini_dec_threshold
        self.min_samples_split = min_samples_split

    def _gini(self, y):
        """Return the Gini impurity ``1 - sum(p_k ** 2)`` of label array ``y``."""
        # One pass over y to count every class instead of re-scanning y
        # once per distinct label.
        _, counts = np.unique(y, return_counts=True)
        sum_sq = 0.
        for count in counts.tolist():
            sum_sq += (count / y.size) ** 2
        return 1 - sum_sq

    def _gini_split(self, y, feature, value):
        """Return the weighted Gini impurity after splitting on ``feature > value``."""
        indices = feature > value
        y1 = y[indices]
        y2 = y[~indices]
        # Weight each subset's impurity by its share of the samples.
        gini1 = self._gini(y1)
        gini2 = self._gini(y2)
        return (y1.size * gini1 + y2.size * gini2) / y.size

    def _get_split_points(self, feature):
        """Return the candidate thresholds for one continuous feature column.

        Candidates are the midpoints between adjacent distinct sorted values.
        """
        values = np.unique(feature)
        lower, upper = values[:-1], values[1:]
        midpoints = (lower + upper) / 2
        # For two adjacent floating-point numbers the midpoint can round up
        # to the larger value; ``feature > midpoint`` would then fail to
        # separate the pair (and may leave one side of the split empty).
        # Falling back to the smaller value keeps the split effective.
        thresholds = np.where(midpoints < upper, midpoints, lower)
        return list(thresholds)

    def _select_feature(self, X, y):
        """Find the (feature, threshold) pair giving the lowest split impurity.

        Returns:
            ``(best_feature_index, best_split_value, min_gini)``.  The first
            two items are ``None`` when there is no candidate split or when
            the impurity decrease is below ``gini_dec_threshold``.
        """
        best_feature_index = None
        best_split_value = None
        min_gini = np.inf
        _, n = X.shape
        for feature_index in range(n):
            feature = X[:, feature_index]
            for value in self._get_split_points(feature):
                gini = self._gini_split(y, feature, value)
                # Strict comparison: on ties the first candidate found wins.
                if gini < min_gini:
                    min_gini = gini
                    best_feature_index = feature_index
                    best_split_value = value
        # Reject the split when it does not reduce the impurity enough.
        if self._gini(y) - min_gini < self.gini_dec_threshold:
            best_feature_index = None
            best_split_value = None
        return best_feature_index, best_split_value, min_gini

    def _node_value(self, y):
        """Return the most frequent class label in ``y``."""
        labels_count = np.bincount(y)
        return np.argmax(labels_count)

    def _build_tree(self, X, y):
        """Recursively build the (sub)tree for samples ``X`` with labels ``y``."""
        node = CartClassificationTree.Node()
        # Every node stores the majority label, whether or not it is a leaf.
        node.value = self._node_value(y)
        # Too few samples to split any further: leaf.
        if y.size < self.min_samples_split:
            return node
        # Node is already pure enough: leaf.
        if self._gini(y) < self.gini_threshold:
            return node
        feature_index, feature_value, _ = self._select_feature(X, y)
        if feature_index is None:
            # No acceptable split exists: leaf.
            return node
        node.feature_index = feature_index
        node.feature_value = feature_value
        # Partition the samples with the chosen feature and threshold.
        indices = X[:, feature_index] > feature_value
        node.left = self._build_tree(X[indices], y[indices])
        node.right = self._build_tree(X[~indices], y[~indices])
        return node

    def _predict_one(self, x):
        """Walk the tree for a single sample ``x`` and return the leaf's label."""
        node = self.tree_
        while node.left is not None:
            if x[node.feature_index] > node.feature_value:
                node = node.left
            else:
                node = node.right
        return node.value

    def train(self, X_train, y_train):
        """Fit the tree and store its root in ``self.tree_``.

        Args:
            X_train: array-like of shape (n_samples, n_features) holding
                numeric feature values.
            y_train: array-like of shape (n_samples,) holding non-negative
                integer class labels.

        Raises:
            ValueError: if the inputs are empty or their shapes do not match.
        """
        # Accept plain Python sequences as well as ndarrays.
        X = np.asarray(X_train)
        y = np.asarray(y_train)
        if X.ndim != 2 or y.ndim != 1 or X.shape[0] != y.shape[0]:
            raise ValueError(
                "X_train must be 2-D and y_train 1-D with one label per sample")
        if y.size == 0:
            raise ValueError("cannot train on an empty data set")
        self.tree_ = self._build_tree(X, y)

    def predict(self, X):
        """Predict a class label for every row of ``X`` (2-D array-like)."""
        return np.apply_along_axis(self._predict_one, axis=1, arr=np.asarray(X))
此程序适用于可以用数字表示的特征值:利用上一篇文章中针对连续特征值寻找最佳分割阈值的方法,先为每个特征找到最合适的分割阈值,进而找到最适合的特征列。通过递归调用构造左右子树,最终得到整棵决策树。
接下来,验证一下效果,选择鸢尾花数据集进行验证(https://archive.ics.uci.edu/ml/datasets/iris):
下载iris.data和iris.names两个文件。从iris.names文件可以了解数据集的属性:
列号 | 列名 | 特征/类标记 | 可取值 |
1 | sepal length | 特征 | 连续实数 |
2 | sepal width | 特征 | 连续实数 |
3 | petal length | 特征 | 连续实数 |
4 | petal width | 特征 | 连续实数 |
5 | class | 类标记 | Iris-versicolor Iris-virginica Iris-setosa |
该数据集共有150条记录,通过以下代码进行训练:
>>> import numpy as np
>>> X = np.genfromtxt('iris.data',delimiter=',',usecols=range(4),dtype=float)
>>> y = np.genfromtxt('iris.data',delimiter=',',usecols=4,dtype=str)
>>> y
array(['Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',
......
'Iris-virginica', 'Iris-virginica', 'Iris-virginica',
'Iris-virginica', 'Iris-virginica'], dtype='<U15')
>>> from sklearn.preprocessing import LabelEncoder
>>> le = LabelEncoder()
>>> y = le.fit_transform(y)
>>> y
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32)
>>> from cartctree import CartClassificationTree
>>> cct = CartClassificationTree()
>>> from sklearn.model_selection import train_test_split
>>> X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.3)
>>> cct.train(X_train, y_train)
观察训练效果:
>>> from sklearn.metrics import accuracy_score
>>> y_predict = cct.predict(X_test)
>>> y_predict
array([2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 0,
2, 2, 0, 2, 2, 0, 1, 1, 0, 1, 2, 2, 0, 2, 2, 1, 0, 1, 0, 0, 2, 0,
1], dtype=int32)
>>> y_test
array([2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 0,
2, 2, 0, 2, 2, 0, 1, 1, 0, 1, 1, 2, 0, 2, 2, 1, 0, 1, 0, 0, 2, 0,
1], dtype=int32)
>>> accuracy_score(y_test, y_predict)
0.9555555555555556
准确率超过95%。
上面训练得到的决策树结果还不够直观,我们需要把这棵决策树绘制出来。上次在《机器学习学习笔记(16)----使用Matplotlib绘制决策树》中绘制过一棵决策树,但是那个Node节点的结构和这个CART分类树算法的Node节点不大一样,因此不能直接使用那篇文章的代码,需要修改适配一下(treeplotter2.py):
import matplotlib.pyplot as plt
from cartctree import CartClassificationTree
class TreePlotter2:
    """Draw a binary decision tree with Matplotlib annotations.

    Nodes are expected to expose ``value``, ``feature_index``,
    ``feature_value``, ``left`` and ``right``; a node whose ``left`` is
    ``None`` is treated as a leaf.
    """

    def __init__(self, tree, feature_names, label_names):
        """Store the tree and the lookup tables used for labelling.

        Args:
            tree: root node of the tree to draw.
            feature_names: mapping ``feature_index -> {'name': str,
                'value_names': {1: left_edge_prefix, 2: right_edge_prefix}}``.
            label_names: mapping ``node.value -> text shown in a leaf``.
        """
        # Box / arrow styles handed to ``annotate``.
        self.decision_node = dict(boxstyle="sawtooth", fc="0.8")
        self.leaf_node = dict(boxstyle="round4", fc="0.8")
        self.arrow_args = dict(arrowstyle="<-")
        self.tree = tree
        self.feature_names = feature_names
        self.label_names = label_names
        # Layout state, initialised by create_plot().
        self.totalW = None
        self.totalD = None
        self.xOff = None
        self.yOff = None

    def _get_num_leafs(self, node):
        """Return the number of leaves below (and including) ``node``."""
        if node.left is None:
            return 1
        return self._get_num_leafs(node.left) + self._get_num_leafs(node.right)

    def _get_tree_depth(self, node):
        """Return the depth of the subtree, counting a lone leaf as 1."""
        if node.left is None:
            return 1
        return 1 + max(self._get_tree_depth(node.left),
                       self._get_tree_depth(node.right))

    def _plot_mid_text(self, cntrpt, parentpt, txtstring, ax1):
        """Write ``txtstring`` halfway between a node and its parent."""
        x_mid = (parentpt[0] - cntrpt[0]) / 2.0 + cntrpt[0]
        y_mid = (parentpt[1] - cntrpt[1]) / 2.0 + cntrpt[1]
        # Node positions are axes fractions (see _plot_node), so place the
        # text in the same coordinate system rather than in data coordinates.
        ax1.text(x_mid, y_mid, txtstring, transform=ax1.transAxes)

    def _plot_node(self, nodetxt, centerpt, parentpt, nodetype, ax1):
        """Draw one node box at ``centerpt`` with an arrow from ``parentpt``."""
        ax1.annotate(nodetxt, xy=parentpt,
                     xycoords='axes fraction',
                     xytext=centerpt, textcoords='axes fraction',
                     va="center", ha="center", bbox=nodetype,
                     arrowprops=self.arrow_args)

    def _plot_tree(self, tree, parentpt, nodetxt, ax1):
        """Recursively draw the subtree rooted at ``tree``.

        Args:
            tree: node to draw (leaf or internal node).
            parentpt: ``(x, y)`` position of the parent node.
            nodetxt: text written on the edge coming from the parent.
            ax1: axes object to draw on.
        """
        if tree.left is None:
            # Leaf (also covers a tree that consists of a single leaf):
            # advance to the next leaf slot and draw the class label.
            self.xOff = self.xOff + 1.0 / self.totalW
            leaf_pt = (self.xOff, self.yOff)
            self._plot_node(self.label_names[tree.value], leaf_pt, parentpt,
                            self.leaf_node, ax1)
            self._plot_mid_text(leaf_pt, parentpt, nodetxt, ax1)
            return

        # Internal node: centre it horizontally above its leaves.
        num_leafs = self._get_num_leafs(tree)
        feature = self.feature_names[tree.feature_index]
        cntrpt = (self.xOff + (1.0 + float(num_leafs)) / 2.0 / self.totalW,
                  self.yOff)
        self._plot_mid_text(cntrpt, parentpt, nodetxt, ax1)
        self._plot_node(feature['name'], cntrpt, parentpt,
                        self.decision_node, ax1)
        # Children are drawn one level lower.
        self.yOff = self.yOff - 1.0 / self.totalD
        # value_names[1] labels the edge to the left child, value_names[2]
        # the edge to the right child.
        for child, key in ((tree.left, 1), (tree.right, 2)):
            edge_txt = feature['value_names'][key] + str(tree.feature_value)
            self._plot_tree(child, cntrpt, edge_txt, ax1)
        # Restore the level for the caller's remaining siblings.
        self.yOff = self.yOff + 1.0 / self.totalD

    def create_plot(self):
        """Create the figure, draw the whole tree and show it."""
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        # No ticks and no frame around the drawing.
        axprops = dict(xticks=[], yticks=[])
        ax1 = plt.subplot(111, frameon=False, **axprops)
        # Total width is the number of leaves, total height the tree depth.
        self.totalW = float(self._get_num_leafs(self.tree))
        self.totalD = float(self._get_tree_depth(self.tree))
        self.xOff = -0.5 / self.totalW
        self.yOff = 1.0
        # The root is pinned to the top centre, (0.5, 1.0).
        self._plot_tree(self.tree, (0.5, 1.0), '', ax1)
        plt.show()
使用如下代码绘制上面的CART分类树:
>>> features_dict = {
0 : {'name' : 'sepal length',
'value_names': { 1: '>',
2: '<='}
},
1 : {'name' : 'sepal width',
'value_names': { 1: '>',
2: '<='}
},
2 : {'name' : 'petal length',
'value_names': { 1: '>',
2: '<='}
},
3 : {'name' : 'petal width',
'value_names': { 1: '>',
2: '<='}
}
}
>>> label_dict = {
0: 'Iris-setosa',
1: 'Iris-versicolor',
2: 'Iris-virginica'
}
>>> from treeplotter2 import TreePlotter2
>>> plotter = TreePlotter2(cct.tree_, features_dict, label_dict)
>>> plotter.create_plot()
可得CART二叉分类树如下图:
参考资料:
《Python机器学习算法:原理,实现与案例》 刘硕 著