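This post implements a CART tree in Python 3, following the blog post at http://www.dmlearning.cn/single/6362dfbeddd6448c9ff1ceaf6eec0ef9.html
First, thanks to the author of that post. This is the first time I have implemented a machine learning algorithm myself while studying the subject, and doing so deepened my understanding of CART. (I strongly recommend implementing the common machine learning algorithms yourself wherever you can; the formulas and other details become much clearer.) Explanations of the CART concept and formulas are easy to find online, so this post covers only the concrete implementation. If you spot any mistakes, please let me know. Thanks.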
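import numpy as np
import pandas as pd

class TreeNode(object):
    """A node of the CART tree.

    Internal nodes hold the split threshold in `value` and the two subtrees;
    leaf nodes hold their (sample, label) pairs in `result`.
    """
    def __init__(self, trueBranch = None, falseBreach = None, value = None, result = None):
        self.trueBranch = trueBranch    # subtree for samples with feature >= value
        self.falseBreach = falseBreach  # subtree for samples with feature < value
        self.value = value              # split threshold (None for a leaf)
        self.result = result            # list of (x, y) pairs (None for an internal node)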
# Load the data; the 'Name' column is the class label, the remaining columns are features.
data = pd.read_csv('./fishiris.csv')
y = data['Name']
x = data.drop(["Name"], axis=1)
# Count how many times each distinct label occurs.
def calculateDiffCount(datas):
    result = {}
    for data in datas:
        if data not in result:
            result[data] = 1
        else:
            result[data] += 1
    return result
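# Compute the Gini index of a set of labels (the samples selected by a split on some feature).
def Gini(datas):
    length = len(datas)
    imp = 0.0
    counts = calculateDiffCount(datas)
    for i in counts:
        # Accumulate the squared proportion of each class.
        imp += np.square(counts[i] / length)
    return 1 - imp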
# Build the CART tree recursively; splitting stops when no split yields a positive Gini gain.
def bulidDecisionTree(x, y):
    currentGini = Gini(y)
    print('currentGini = ', currentGini)
    bestGini = 0.0
    colnums = len(x[0])
    rows = len(x)
    for col in range(colnums):
        # Every distinct value in this column is tried as a threshold.
        val_set = set(row[col] for row in x)
        for val in val_set:
            listx1, listx2, listy1, listy2 = SplitData(val, col, x, y)
            p = len(listy1) / rows
            gain = currentGini - p * Gini(listy1) - (1 - p) * Gini(listy2)
            print('gain = ', gain)
            if gain > bestGini:  # keep the split with the largest Gini gain
                bestGini = gain
                Bestvalue = val
                bestlist = [listx1, listy1, listx2, listy2]
    if bestGini > 0:
        trueBranch = bulidDecisionTree(bestlist[0], bestlist[1])
        falseBranch = bulidDecisionTree(bestlist[2], bestlist[3])
        # Note: only the threshold is stored on the node; the column index is not kept.
        return TreeNode(value = Bestvalue, trueBranch = trueBranch, falseBreach = falseBranch)
    else:
        # No split improves the Gini index, so this node becomes a leaf.
        return TreeNode(result = list(zip(x, y)))
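# Split the samples on column `col`: rows with a value >= `value` go into the
# first pair of lists (listx1, listy1), all other rows into the second pair.
def SplitData(value, col, x, y):
    listx1 = []
    listx2 = []
    listy1 = []
    listy2 = []
    for (tempx, tempy) in zip(x, y):
        if tempx[col] >= value:
            listx1.append(tempx)
            listy1.append(tempy)
        else:
            listx2.append(tempx)
            listy2.append(tempy)
    return listx1, listx2, listy1, listy2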
xdata = x.values.tolist()
ydata = y.values.tolist()
root = bulidDecisionTree(xdata, ydata)
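To show how a tree of this shape could be used once it is built, here is a minimal sketch of classifying a single sample. It is not part of the program above: TreeNode stores only the split threshold, so the sketch assumes a hypothetical node class `SketchNode` with an extra `col` field for the split column, and a hypothetical `classify` helper that returns the majority label of the leaf it reaches. The toy tree and its values are made up for illustration.

```python
from collections import Counter


class SketchNode(object):
    """Hypothetical node: like TreeNode, plus the split column index `col`."""

    def __init__(self, trueBranch=None, falseBranch=None, col=None, value=None, result=None):
        self.trueBranch = trueBranch
        self.falseBranch = falseBranch
        self.col = col          # assumption: index of the feature used for the split
        self.value = value      # split threshold
        self.result = result    # list of (x, y) pairs for a leaf, else None


def classify(node, sample):
    """Walk from `node` down to a leaf and return that leaf's majority label."""
    while node.result is None:
        # Same convention as SplitData: values >= threshold go to the true branch.
        if sample[node.col] >= node.value:
            node = node.trueBranch
        else:
            node = node.falseBranch
    labels = [label for _, label in node.result]
    return Counter(labels).most_common(1)[0][0]


# A hand-built tree with one split and two leaves (illustrative values only).
leaf_big = SketchNode(result=[([5.1, 1.9], 'virginica'), ([4.8, 1.8], 'virginica')])
leaf_small = SketchNode(result=[([1.4, 0.2], 'setosa')])
toy_root = SketchNode(trueBranch=leaf_big, falseBranch=leaf_small, col=0, value=4.8)

print(classify(toy_root, [5.0, 2.0]))
print(classify(toy_root, [1.5, 0.3]))
```

Storing the column index on each internal node is what makes this walk possible; with only the threshold, the node cannot tell which feature to compare.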
The core of the program is computing the Gini index and growing the tree: the Gini index is computed by def Gini(datas), and the tree is built by def bulidDecisionTree(x, y).
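As a small worked example of these two quantities, the sketch below computes the Gini index of a toy label list and the Gini gain of two candidate splits, using the same formulas as Gini and bulidDecisionTree. The helper names `gini` and `gini_gain` and the toy labels are made up for illustration and do not appear in the program above.

```python
from collections import Counter


def gini(labels):
    """Gini index: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def gini_gain(parent, left, right):
    """Gini of the parent minus the size-weighted Gini of the two children."""
    p = len(left) / len(parent)
    return gini(parent) - p * gini(left) - (1 - p) * gini(right)


parent = ['a', 'a', 'b', 'b']
print(gini(parent))                                # 0.5
print(gini_gain(parent, ['a', 'a'], ['b', 'b']))   # 0.5, a perfect split
print(gini_gain(parent, ['a', 'b'], ['a', 'b']))   # 0.0, an uninformative split
```

A split that separates the classes completely recovers the whole parent impurity as gain, while a split that leaves both children as mixed as the parent gains nothing, which is why the tree stops splitting when the best gain is 0.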
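The program still has many gaps: it does not compute the Gini gain for string-valued variables, and it does no pruning. Continuous features also call for a different procedure: sort the values of the chosen feature in ascending order, take the mean of each pair of adjacent values as a candidate threshold, compute the Gini gain of the resulting left and right subtrees for each candidate, and pick the best split point.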
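The midpoint procedure for continuous features can be sketched as follows. This is only an illustration under a few assumptions: the helper names `gini` and `best_midpoint_split` are hypothetical, duplicate feature values are collapsed before midpoints are taken, and samples below the threshold go left while the rest go right.

```python
from collections import Counter


def gini(labels):
    """Gini index: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_midpoint_split(values, labels):
    """Return (threshold, gain) of the best midpoint split, or (None, 0.0) if none helps."""
    parent_gini = gini(labels)
    distinct = sorted(set(values))           # ascending, duplicates removed
    best_threshold, best_gain = None, 0.0
    # Candidate thresholds are the means of adjacent distinct values.
    for low, high in zip(distinct, distinct[1:]):
        threshold = (low + high) / 2
        left = [lab for v, lab in zip(values, labels) if v < threshold]
        right = [lab for v, lab in zip(values, labels) if v >= threshold]
        p = len(left) / len(labels)
        gain = parent_gini - p * gini(left) - (1 - p) * gini(right)
        if gain > best_gain:
            best_threshold, best_gain = threshold, gain
    return best_threshold, best_gain


values = [1.0, 2.0, 3.0, 4.0]
labels = ['a', 'a', 'b', 'b']
print(best_midpoint_split(values, labels))   # (2.5, 0.5)
```

Compared with trying every observed value as a threshold, midpoints place the boundary between two observed values rather than on one of them, and a feature with n distinct values yields n - 1 candidates.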
The program is still incomplete and I will keep improving it, but even as it stands it is useful for understanding the core of CART. Thanks.