CART回归树&模型树 生成 剪枝 in Python

现实中,数据集中经常包含一些复杂的相互关系,使得输入数据和目标变量之间呈现非线性关系。对这些复杂的关系建模,一种可行的方式是使用树来对预测值进行分段,包括分段常数或者分段直线,即通过树结构对数据进行切分后,在叶节点上,对叶节点上的数据,取均值构造回归树,或者取线性模型构造模型树。

下面,我们统一将基于CART的回归树和模型树称作树回归。


1、树回归的特点


1.1 相对之前提到的ID3决策树来说,基于二元切分的树回归切分不会过快,而且可以处理连续性特征数据。

1.2 优点:可以对复杂和非线性数据建模

1.3 缺点:结果不像线性回归那么好理解

1.4 模型树可解释性由于回归树,相对而言,模型树也具有更高的预测准确度。


2、各种回归方法的比较


对于模型树、回归树和之前的线性回归,一种比较客观的比较方法是计算相关系数,即R^2值。

只需调用Numpy库中的命令corrcoef(yHat,y,rowvar=0)即可,其中yHat为模型预测值,y是目标变量的实际值。

R^2值越接近1.0说明预测性能越好。


3、几个主要函数伪代码


3.1 确定数据集切分的最佳位置 chooseBestSplit()函数

如果数据集中目标变量只有一种:
	不进行后续切分,直接将此数据集构建为叶节点	

对每个特征:
	对每个特征值:
		将数据集切分为两份
		计算切分后两个子数据集的误差和
		如果此误差和小于当前最小误差:
			将当前切分设定为最佳切分并更新最小误差

如果数据集上的误差和当前最小误差之间没有达到设定的容许误差下降值:
	不进行后续切分,直接将此数据集构建为叶节点	

如果切分后的子数据集中的样本数低于设定的最少样本数:
	不进行后续切分,直接将此数据集构建为叶节点	
			
返回记录的最佳切分的特征和切分点


3.2 树的生成算法 createTree()函数

调用chooseBestSplit()找到最佳待切分特征:
	如果该节点不能再分,即待切分特征无:
		将该节点存为叶节点
	执行二元切分
	在右子树调用createTree()函数
	在左子树调用createTree()函数

3.3 后剪枝算法 prune()函数

实际上,chooseBestSplit()函数中的三个如果已经对树生成过程进行了预剪枝,但此操作与算法设定的停止条件相关,不太好操作,还是要使用测试数据集进行后剪枝。

基于前面所得的树对测试数据进行切分:
	如果存在任一子集不是叶节点而是树:
		在该子集上调用prune()函数
	计算此时标准二分树的误差:
		即两个子叶节点上的误差和
	计算将当前两个叶节点合并后的误差:
		即当前标准二分树的根节点值取两叶节点均值后构成的单节点结构的误差
	如果合并后降低误差的话,就将此两叶节点进行合并


4、Python实现

from numpy import *

def loadDataSet(fileName):
	# creat a list, but following dataSet represents matrix
	dataMat = []
	fr = open(fileName)
	for line in fr.readlines():
		currLine = line.strip().split('\t')
		fltLine = map(float, currLine)
		dataMat.append(fltLine)
	return dataMat
	

### Preparing for creating tree
# the function regleaf and modelleaf is going to create the leafnodes	
def regLeaf(dataSet):
	return mean(dataSet[:,-1])
	
def regErr(dataSet):
	return shape(dataSet)[0] * var(dataSet[:,-1])
	# used for measuring the uniformity of data
	# or say for calculating the chaos of data 

def linearSolve(dataSet):
	N, n = shape(dataSet)
	X = mat(ones((N,n)))
	Y = mat(ones((N,1)))
	X[:, 1:n] = dataSet[:, 0:n-1]
	Y = dataSet[:, -1]
	xTx = X.T * X
	if linalg.det(xTx) == 0.0 :
		raise NameError('This matrix is singular, cannot do inverse, \n\
		try increasing the second value of ops')
	ws = xTx.I * (X.T * Y)
	return ws, X, Y
	
def modelLeaf(dataSet):
	ws, X, Y = linearSolve(dataSet)
	return ws

def modelErr(dataSet):
	ws, X, Y = linearSolve(dataSet)
	yHat = X * ws
	return sum(power(Y-yHat, 2))
	

### Creating tree
def binSplitDataSet(dataSet, feature, value):
	mat0 = dataSet[nonzero(dataSet[:,feature] >value)[0], :]
	mat1 = dataSet[nonzero(dataSet[:,feature]<=value)[0], :]
	return mat0, mat1

def chooseBestSplit(dataSet, leafType, errType, ops):
	tolS = ops[0]	# desent error value tolerated
	tolN = ops[1]	# minimum number of samples splited
	if len(set(dataSet[:,-1].T.tolist()[0])) == 1 :
		return None, leafType(dataSet)
	N, n = shape(dataSet)
	S = errType(dataSet)
	bestS = inf; bestIndex = 0; bestValue = 0
	for featIndex in range(n-1) :
		for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
			mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
			if (shape(mat0)[0]<tolN) or (shape(mat1)<tolN):
				continue
			newS = errType(mat0) + errType(mat1)
			if newS < bestS:
				bestIndex = featIndex
				bestValue = splitVal
				bestS = newS
	if (S-bestS) < tolS :
		return None, leafType(dataSet)
	mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
	if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN) :
		return None, leafType(dataSet)
	return bestIndex, bestValue

def createTree(dataSet, leafType, errType, ops):
	feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
	if feat == None:
		return val
	retTree = {}
	retTree['spInd'] = feat
	retTree['spVal'] = val
	lSet, rSet = binSplitDataSet(dataSet, feat, val)
	retTree['left'] = createTree(lSet, leafType, errType, ops)
	retTree['right'] = createTree(rSet, leafType, errType, ops)
	return retTree

	
### Post Purning
def isTree(obj):
	return (type(obj).__name__ == 'dict')

def getMean(tree):
	if isTree(tree['right']): 
		tree['right'] = getMean(tree['right'])
	if isTree(tree['left']): 
		tree['left'] = getMean(tree['left'])
	return (tree['right']+tree['left'])/2.0

def postPurning(tree, testData):
	if shape(testData)[0] == 0 :
		return getMean(tree)
	if isTree(tree['right']) or isTree(tree['left']) :
		lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
	if isTree(tree['left']) :
		tree['left'] = postPurning(tree['left'], lSet)
	if isTree(tree['right']) :
		tree['right'] = postPurning(tree['right'], rSet)
	if not isTree(tree['left']) and not isTree(tree['right']) :
		lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
		errNoMerge = sum(power(lSet[:,-1]-tree['left'], 2)) + sum(power(rSet[:,-1]-tree['right'], 2))
		treeMean = (tree['left']+tree['right']) / 2.0
		errMerge = sum(power(testData[:,-1]-treeMean, 2))
		if errMerge < errNoMerge :
			print "merging"
			return treeMean
		else:
			return tree
	else:
		return tree

		
### Predicting
def regTreeEval(model, inData):
	return float(model)
	
def modelTreeEval(model, inData):
	n = shape(inData)[1]
	X = mat(zeros((1,n+1)))
	X[:,1:n+1] = inData
	return float(X*model)

def treeForecast(tree, inData, treeEval):
	if not isTree(tree):
		return treeEval(tree, inData)
	if inData[tree['spInd']] > tree['spVal'] :
		if isTree(tree['left']) :
			return treeForecast(tree['left'], inData, treeEval):
		else:
			return treeEval(tree['left'], inData)
	else:
		if isTree(tree['right']) :
			return treeForecast(tree['right'], inData, treeEval):
		else:
			return treeEval(tree['right'], inData)

def createForecast(tree, testData, treeEval):
	M = len(testData)
	yHat = mat(zeors(M,1))
	for ii in range(M):
		yHat[ii,0] = treeForecast(tree, mat(testData[ii]), treeEval)
	return yHat



  • 10
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值