Chapter9 - Tree-based regression
CART是classification and regression tree,分类与回归树,正如名字所说的,它其实有两种树,分类树和回归树。第三章中讲的决策树是ID3决策树,根据信息增益作为特征选择算法。
CART树与前面说的树有什么差别呢?
1.之前的生成树的算法在对某个特征切分的时候,将数据集按这个特征的所有取值分成很多部分,这样的切分速度太快,而CART只进行二元切分
,对每个特征只切分成两个部分。
2.ID3和C4.5只能处理离散型变量,而CART因为是二元切分,可以处理连续型变量
,而且只要将特征选择的算法改一下的话既可以生成回归树。
本章讲了回归树和模型树
回归树
特征选择
回归树使用的是平方误差最小法作为特征选择算法。
其思想是将当前节点的数据按照某个特征在某个切分点分成两类,比如 R1,R2 ,其对应的类别为 C1,C2 ,我们的任务就是找到一个切分点使误差最小,那么怎么度量误差呢?这里使用的是平方误差,即
min[min∑xi∈R1(yi−c1)2+min∑xi∈R2(yi−c2)2]
遍历某个特征可取的s个切分点(对离散型变量,要么等于要么不等于;对连续型变量,<或者>=),选择使上式最小的切分点。对每个确定的集合,c1,c2取平均值 ∑xi∈R1(yi−c1)2 和 ∑xi∈R2(yi−c2)2 才会最小,这样的话就是求划分为两个集合后,分别对每个集合求方差*实例数,加起来的最小值。
剪枝
简单的剪枝,如果merge后的误差更小就merge
python实现
# 选取最佳分裂特征和值
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
tolS = ops[0]; tolN = ops[1]
# 全属于一类
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
m,n = shape(dataSet)
S = errType(dataSet)
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
# print array(dataSet[:,featIndex].T).tolist()
for splitVal in set(array(dataSet[:,featIndex].T)[0].tolist()):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
continue
# 平方误差
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
# 当新分裂的误差与为分裂的误差小于一个阈值,就不分裂
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex,bestValue
# 创建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None: return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
#简单的剪枝,如果merge后的误差更小就merge
def prune(tree, testData):
if shape(testData)[0] == 0:
return getMean(tree)
if (isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
sum(power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print "merging"
return treeMean
else:
return tree
else: return tree
模型树
树的叶子节点不是一个数值,而是一个模型的参数,如果叶子节点是线性回归模型,那么叶子节点存的就是权值系数w
python实现
# 叶子节点存放的东西
def modelLeaf(dataSet):
ws,X,Y = linearSolve(dataSet)
return ws
#线性模型的误差
def modelErr(dataSet):
ws,X,Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat, 2))
createTree(myMat2, modelLeaf, modelErr,(1,10))