《机器学习实战》第九章 树回归 笔记整理

前言

树回归我一开始并没有弄懂它的意思。以前我看过决策树分类,决策树分类实际是基于一个信息学的一个理论香农熵来做的,即找到一个最适合分类的特征设为节点,然后再对其他特征进行分类,已分过类的特征不再考虑。那这个树又怎么做回归呢,回归不是针对连续型变量嘛,而CART算法以及回归树、模型树给了我一个不错的解答。而引入树回归,也是为了解决线性回归需要对全局进行拟合的问题(局部加权线性回归除外),我们通过树对数据进行划分,形成易于使用通用回归算法进行拟合的一块一块的数据。

CART算法

CART算法全称为classification and regression tree algorithm。CART使用的是二元切分法,是一个二叉树:

  • 待切分的特征
  • 待切分的特征值
  • 右子树
  • 左子树

思路非常简单:createTree()

  • 通过某个标准找到最佳待切分特征
  • 如果该节点不能再分存为叶节点
  • 执行二元切分
  • 在右子树调用createTree()
  • 在左子树调用createTree()

在CART算法的实现代码中,有一些需要注意的:

  1. 加载数据
fltLine = map(float,curLine)#python3这里返回的是一个迭代器而不是一个list
dataMat.append(list(fltLine))#需要用list把迭代器转换成列表
  1. 切分数据
mat0 = dataSet[nonzero(dataSet[:,feature]>value)[0],:]#使用了nonzero
  1. 创建CART
def createTree(dataSet,leafType = regLeaf,errType = regErr,ops = (1,4)):
	feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
	if feat == None: return val
	retTree = {}
	retTree['spInd'] = feat
	retTree['spVal'] = val
	lSet,rSet = binSplitDataSet(dataSet, feat, val)
	retTree['left'] = createTree(lSet,leafType,errType,ops)
	retTree['right'] = createTree(rSet,leafType,errType,ops)
	return retTree

CART算法就是一个统称吧,具体还需要针对具体的树,关键就是如何选择最好切分,以及叶节点的模型

回归树

回归树的叶节点模型就是目标变量的均值,它的数据混乱度的计算实际引入了我们一直用的总方差(平方误差的总值)。从chooseBestSplit函数我们就能够看出来这个回归的原理,实际上这个回归预测全在叶节点上。用户定义的tolS和tolN决定了这个模型的简单或者复杂。

def regLeaf(dataSet):
	return mean(dataSet[:,-1])

def regErr(dataSet):
	return var(dataSet[:,-1])*shape(dataSet)[0]
	
def chooseBestSplit(dataSet,leafType = regLeaf,errType = regErr,ops=(1,4)):
	tolS = ops[0];tolN = ops[1]
	if len(set(dataSet[:,-1].T.tolist()[0]))==1:#用于降低复杂度
		return None,leafType(dataSet)
	m,n = shape(dataSet)
	S = errType(dataSet)
	bestS = inf;bestIndex = 0;bestValue = 0
	for featIndex in range(n-1):
		for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]):
			mat0,mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
			if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN): continue
			newS = errType(mat0) + errType(mat1)
			if newS<bestS:
				bestIndex = featIndex
				bestValue = splitVal
				bestS = newS
	if (S - bestS)<tolS:
		return None,leafType(dataSet)
	mat0,mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
	if(shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):
		return None,leafType(dataSet)
	return bestIndex,bestValue

直接对每个特征的每个特征值进行遍历,直到找到最好的分割,紧接着和tolS比较看值不值进行分割。这个算法易于理解,也比较实在,最后的拟合结果在二维上就是分段直线函数的形式:
在这里插入图片描述

剪枝

对于回归树的一个优化方法是剪枝,剪枝分为预剪枝和后剪枝。我们设置的tolS和tolN提前退出就是预剪枝,但是这种剪枝极其依赖经验,因此又提出使用测试集进行后剪枝。
先创建一个足够大的树,从上而下遍历,如果左子树或者右子树不为叶,则在子树上继续递归剪枝过程,直到两个都是叶节点,开始合并,看是否误差减小(因为只有两个都是叶的时候误差才能计算,而且最低层是效果最差的分割了,如果这样还没有过拟合,那更好的分割肯定也没有过拟合)。这就是后剪枝的算法流程。误差是通过以下式子表示:
Σ ( t e s t 左 − t r e e 左 ) 2 + Σ ( t e s t 右 − t r e e 右 ) 2 \Sigma(test_左 - tree_左)^2+\Sigma(test_右 - tree_右)^2 Σ(testtree)2+Σ(testtree)2

def isTree(obj):
	return (type(obj).__name__=='dict')

def getMean(tree):
	if isTree(tree['right']): tree['right'] = getMean(tree['right'])
	if isTree(tree['left']): tree['left'] = getMean(tree['left'])
	return (tree['left']+tree['right'])/2.0

def prune(tree,testData):
	if shape(testData)[0] == 0: return getMean(tree)
	if (isTree(tree['right']) or isTree(tree['left'])):
		lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])

	if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
	if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
	if not(isTree(tree['right'])) and not(isTree(tree['left'])):
		lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
		errorNoMerge = sum(power(lSet[:,-1]-tree['left'],2))+sum(power(rSet[:,-1]-tree['right'],2))
		treeMean = (tree['left']+tree['right'])/2.0
		errorMerge = sum(power(testData[:,-1]-treeMean,2))
		if errorMerge<errorNoMerge:
			print("merging")
			return treeMean
		else:
			return tree
	else: return tree

模型树

模型树和局部加权线性模型很像,因为在模型树中每个叶节点都是一个线性回归模型,它可以不错的拟合多个线性片段。为了实现模型树,它的分割误差计算和回归树的总方差不同了,因为回归树的叶节点就是平均值,对于模型树的叶节点,需要先利用分割后的数据计算两个节点的回归模型,然后使用这个模型计算数据的预测值,再和真实值进行差值平方求和作为误差。找到最小误差,便找到最佳分割。线性回归模型也是使用第八章最简单的线性模型。

def linearSolve(dataSet):
	m,n = shape(dataSet)
	X = mat(ones((m,n)));Y = mat(ones)
	X[:,1:n] = dataSet[:,0:n-1];Y = dataSet[:,-1]
	xTx = X.T*X
	if linalg.det(xTx) == 0.0:
		raise NameError('This matrix is singular, cannot do inverse,\n\
		try increasing the second value of ops')
	ws = xTx.I*(X.T*Y)
	return ws,X,Y   #线性回归标准做法

def modelLeaf(dataSet):
	ws,X,Y = linearSolve(dataSet)
	return ws

def modelErr(dataSet):
	ws,X,Y = linearSolve(dataSet)
	yHat = X*ws
	return sum(power(Y-yHat,2))

二维数据拟合结果如下:
在这里插入图片描述
我们可以看到模型树和回归树明显的差别。

比较模型的优劣 R 2 R^2 R2

利用协方差来比较模型的优劣,因为协方差就是刻画数据之间线性相关性的一个参数,学过概率论和数理统计应该很理解这件事。

##预测代码
def regTreeEval(model, inDat):
    return float(model)

def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1,n+1)))
    X[:,1:n+1]=inDat
    return float(X*model)

def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree): return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)
        else: return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)
        else: return modelEval(tree['right'], inData)

def createForeCast(tree,testData,modelEval = regTreeEval):
	m = len(testData)
	yHat = mat(zeros((m,1)))
	for i in range(m):
		yHat[i,0] = treeForeCast(tree, mat(testData[i]),modelEval)
	return yHat

利用GUI定性显示模型优劣

这里我就直接放代码了,因为我也没有深究tkinter和matplotlib的使用。
说明一下,这里适用于python3,python2还请去看书中附的代码。

from numpy import*
from tkinter import*
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import regTrees
def reDraw(tolS,tolN):
	reDraw.f.clf()
	reDraw.a = reDraw.f.add_subplot(111)
	if chkBtnVar.get():
		if tolN < 2: tolN = 2
		myTree = regTrees.createTree(reDraw.rawDat,regTrees.modelLeaf,regTrees.modelErr,(tolS,tolN))
		yHat = regTrees.createForeCast(myTree, reDraw.testDat,regTrees.modelTreeEval)
	else:
		myTree = regTrees.createTree(reDraw.rawDat,ops=(tolS,tolN))
		yHat = regTrees.createForeCast(myTree, reDraw.testDat)
	reDraw.a.scatter(reDraw.rawDat[:,0].tolist(), reDraw.rawDat[:,1].tolist(), s=5) #use scatter for data set
	reDraw.a.plot(reDraw.testDat,yHat,linewidth=2.0)
	reDraw.canvas.show()

def getInputs():
	try:tolN = int(tolNentry.get())
	except:
		tolN = 10
		print("enter Integer for tolN")
		tolNentry.delete(0,END)
		tolNentry.insert(0, '10')
	try:tolS = float(tolSentry.get())
	except:
		tolS = 1.0
		print("enter float for tolS")
		tolSentry.delete(0,END)
		tolSentry.insert(0, '1.0')
	return tolN,tolS

def drawNewTree():
	tolN,tolS = getInputs()
	reDraw(tolS, tolN)

root = Tk()

reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

#Label(root,text = "Plot Place Holder").grid(row=0,columnspan = 3)

Label(root,text = "tolN").grid(row = 1,column = 0)
tolNentry = Entry(root)
tolNentry.grid(row = 1,column = 1)
tolNentry.insert(0,'10')

Label(root,text = "tolS").grid(row = 2,column = 0)
tolSentry = Entry(root)
tolSentry.grid(row = 2,column = 1)
tolSentry.insert(0,'1.0')

Button(root,text = "ReDraw",command = drawNewTree).grid(row = 1,column = 2,rowspan = 3)

Button(root,text = "Quit",fg = "black",command = root.quit()).grid(row = 1,column = 2)

chkBtnVar = IntVar()
chkBtn = Checkbutton(root,text = "Model Tree",variable = chkBtnVar)
chkBtn.grid(row = 3,column = 0,columnspan = 2)
reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]), max(reDraw.rawDat[:,0]), 0.01)
reDraw(1.0,10)
root.mainloop()

40

30

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值