4. Pruning
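Pruning comes in two forms: pre-pruning and post-pruning. Pre-pruning imposes stopping conditions while the tree is being built so that it does not split excessively; the code in the previous section already includes pre-pruning.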
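As a reminder of what such a stopping condition can look like, here is a minimal sketch. It is not the previous section's code: the function name should_stop_splitting and the thresholds tol_s (minimum error reduction) and tol_n (minimum samples per child) are hypothetical names chosen for illustration.

```python
def should_stop_splitting(err_before, err_after, n_left, n_right,
                          tol_s=1.0, tol_n=4):
    """Return True if a candidate split should be rejected (pre-pruning).

    All names and default thresholds here are illustrative assumptions.
    """
    # Rule 1: the split must reduce the total squared error by at least tol_s.
    if err_before - err_after < tol_s:
        return True
    # Rule 2: each child must receive at least tol_n samples.
    if n_left < tol_n or n_right < tol_n:
        return True
    return False


print(should_stop_splitting(10.0, 9.5, 20, 20))  # error drops by only 0.5
print(should_stop_splitting(10.0, 2.0, 20, 2))   # right child has 2 samples
print(should_stop_splitting(10.0, 2.0, 20, 20))  # both rules satisfied
```

The weakness of this approach is that the thresholds must be chosen by hand and the result can be quite sensitive to them, which is one reason post-pruning is worth having as well.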
The rest of this section focuses on post-pruning.
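Post-pruning algorithm:
Input: a fully grown tree and a test set
Output: the pruned tree
Steps:
(1) If either child of the current node is itself a subtree, apply the pruning procedure recursively to that subtree, using the part of the test set that falls into it
(2) Compute the error on the test set if the two leaf nodes were merged
(3) Compute the error on the test set if they are not merged
(4) If merging lowers the error, merge the two leaf nodes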
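Steps (2) to (4) can be tried in isolation on a handful of numbers. The sketch below is only an illustration: merge_decision is a hypothetical helper, not part of the article's code, and it assumes that each leaf stores a single predicted value and that the error is the sum of squared differences on the test targets.

```python
import numpy as np


def merge_decision(left_leaf, right_leaf, left_targets, right_targets):
    """Decide whether two sibling leaves should be merged.

    left_targets / right_targets are the test-set target values that fall
    into the left and right leaf respectively (hypothetical helper).
    """
    left_targets = np.asarray(left_targets, dtype=float)
    right_targets = np.asarray(right_targets, dtype=float)

    # Step (3): squared error when the two leaves are kept separate.
    error_no_merge = (np.sum((left_targets - left_leaf) ** 2)
                      + np.sum((right_targets - right_leaf) ** 2))

    # Step (2): squared error when both leaves are replaced by their mean.
    merged_leaf = (left_leaf + right_leaf) / 2.0
    all_targets = np.concatenate([left_targets, right_targets])
    error_merge = np.sum((all_targets - merged_leaf) ** 2)

    # Step (4): merge only if that lowers the test error.
    return bool(error_merge < error_no_merge), merged_leaf


# Leaves predict 1.0 and 3.0, but the test targets all sit near 2.0,
# so the merged leaf (2.0) fits better: 0.02 versus 4.42.
print(merge_decision(1.0, 3.0, [2.0, 2.1], [1.9, 2.0]))

# Here the test targets agree with the separate leaves, so no merge.
print(merge_decision(1.0, 3.0, [1.0, 1.1], [3.0, 2.9]))
```

Note that the merged value is the mean of the two leaf values, not the mean of the test targets, which matches the treeMean computation in the pruning code.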
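import numpy as np

# binSplitDataSet is not defined here; it comes from the tree-building code in the previous section.

def isTree(obj):
    # Internal nodes are stored as dicts; leaves are plain numbers.
    return isinstance(obj, dict)

def getMean(tree):
    # Collapse a subtree into a single value: the mean of its two (collapsed) children.
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    # No test data reaches this subtree, so collapse it.
    if np.shape(testData)[0] == 0: return getMean(tree)
    # Step (1): prune any child that is still a subtree, using its share of the test data.
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # Both children are leaves now: compare the test error with and without merging.
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # Step (3): error when the two leaves are kept.
        errorNoMerge = np.sum(np.power(lSet[:, -1] - tree['left'], 2)) + \
                       np.sum(np.power(rSet[:, -1] - tree['right'], 2))
        # Step (2): error when both leaves are replaced by their mean.
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
        # Step (4): merge only if it lowers the error.
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree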
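To see the recursion at work without the rest of the article's code, the following self-contained sketch restates the same logic in compact form and runs it on a small hand-built tree. Everything in it is an assumption made for illustration: bin_split stands in for binSplitDataSet (here rows with feature > value go left, which may differ from the previous section's convention), the tree and the test data are made up, and plain NumPy arrays are used.

```python
import numpy as np


def bin_split(data, feature, value):
    # Stand-in for binSplitDataSet. Assumed convention:
    # rows with feature > value go left, the rest go right.
    return data[data[:, feature] > value], data[data[:, feature] <= value]


def is_tree(obj):
    return isinstance(obj, dict)


def collapse(tree):
    # Mean of the two children, collapsing nested subtrees first.
    left = collapse(tree['left']) if is_tree(tree['left']) else tree['left']
    right = collapse(tree['right']) if is_tree(tree['right']) else tree['right']
    return (left + right) / 2.0


def prune_sketch(tree, test_data):
    if test_data.shape[0] == 0:          # no test data reaches this node
        return collapse(tree)
    left_set, right_set = bin_split(test_data, tree['spInd'], tree['spVal'])
    if is_tree(tree['left']):
        tree['left'] = prune_sketch(tree['left'], left_set)
    if is_tree(tree['right']):
        tree['right'] = prune_sketch(tree['right'], right_set)
    if is_tree(tree['left']) or is_tree(tree['right']):
        return tree                      # a child survived, nothing to merge
    error_no_merge = (np.sum((left_set[:, -1] - tree['left']) ** 2)
                      + np.sum((right_set[:, -1] - tree['right']) ** 2))
    merged = (tree['left'] + tree['right']) / 2.0
    error_merge = np.sum((test_data[:, -1] - merged) ** 2)
    return merged if error_merge < error_no_merge else tree


def make_tree():
    # Root splits feature 0 at 0.5; the left branch splits again at 0.75.
    return {'spInd': 0, 'spVal': 0.5,
            'left': {'spInd': 0, 'spVal': 0.75, 'left': 3.0, 'right': 1.0},
            'right': 0.0}


# Columns: feature x, target y. For x > 0.5 the targets are all close to 2.0,
# so the split at 0.75 does not help on this data.
test_data = np.array([[0.9, 2.0], [0.8, 2.1], [0.6, 2.0], [0.7, 1.9],
                      [0.1, 0.0], [0.2, 0.1]])

pruned = prune_sketch(make_tree(), test_data)
print(pruned)  # the inner split is merged into 2.0, the root split is kept
```

On this data the inner split is merged because the merged leaf has a test error of 0.02 against 3.62 for the two separate leaves, while the root split survives because merging it would raise the error from 0.03 to 5.83.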