CART树实现

本文为CART树实现,基于python3语言,  参考的博客为http://www.dmlearning.cn/single/6362dfbeddd6448c9ff1ceaf6eec0ef9.html

首先向原博客作者表示感谢,其次这是本人学习机器学习算法第一次实现机器学习算法,感觉对CART树理解又进了一步,(本人强烈建议把常用的机器学习算法能够尽量的实现一遍,会对公式的理解和其他细节更加清晰明了。)CART树的概念和公式网上搜索一大把,本文只是介绍CART树具体实现。如有错误请不吝赐教,谢谢

import numpy as np 
import pandas as pd 

class TreeNode(object):
	"""docstring for TreeNode"""
	def __init__(self, trueBranch = None, falseBreach = None, value = None, result = None):
		self.trueBranch = trueBranch
		self.falseBreach = falseBreach
		self.value = value
		self.result = result

data = pd.read_csv('./fishiris.csv')
y = data['Name']
x = data.drop(["Name"], axis=1 )

def calculateDiffCount(datas):
	result = {}
	for data in datas:
		if(data not in result):
			result[data] = 1
		else: 
			result[data] += 1
	return result
#计算每一个feature下所选元素的基尼系数
def Gini(datas):
	length = len(datas)
	imp = 0.0
	counts = calculateDiffCount(datas)
	for i in counts:
		imp += np.square(counts[i] / length)
	return 1 - imp
#构建CART树,使用递归方法构建分类回归树,其中在所分类的基尼指数增益为0时停止分裂。
def bulidDecisionTree(x, y):
	currentGini = Gini(y)
	print('currentGini = ', currentGini)
	bestGini = 0.0
	colnums = len(x[0])
	rows = len(x)
	for col in range(colnums):
		val_set = set(x[col] for x in x)
		for val in val_set:
			listx1, listx2, listy1, listy2 = SplitData(val, col, x, y)
			p = len(listy1) / rows
			gain = currentGini - p * Gini(listy1) - (1 -p) *Gini(listy2)
			print('gain = ', gain)
			if gain > bestGini:         #选择基尼指数增益最大为分裂点
				bestGini = gain
				Bestvalue = val
				bestlist = [listx1, listy1, listx2, listy2]

	if bestGini > 0:
		trueBranch = bulidDecisionTree(bestlist[0], bestlist[1])
		falseBranch = bulidDecisionTree(bestlist[2], bestlist[3])

		return TreeNode(value = Bestvalue, trueBranch = trueBranch, falseBreach = falseBranch)
	else:
		return TreeNode(result = list(zip(x, y)))


def SplitData(value, col,  x, y):
	listx1 = []
	listx2 = []
	listy1 = []
	listy2 = []

	for (tempx, tempy) in zip(x, y):
		if(tempx[col] >= value):
			listx1.append(tempx)
			listy1.append(tempy)
		else:
			listx2.append(tempx)
			listy2.append(tempy)
	return listx1, listx2, listy1, listy2


xdata = x.values.tolist()
ydata = y.values.tolist()
root = bulidDecisionTree(xdata, ydata)

首先核心的地方就是计算gini指数和生成树,其中 gini指数为def Gini(datas): 生成树为def bulidDecisionTree(x, y):

该程序还有很多未完善的地方,比如对string变量的gini指数增益计算,还有剪枝操作,还有对连续值进行gini指数计算时过程为先将所选feature的元素进行从小到大排列,后将两两元素取均值进行左右两树基尼指数增益的计算,选取最佳切分点,

该程序还不太完善,后续会继续完善该程序, 但对理解CART树的核心内容还是很有用的。谢谢

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值