当数据拥有众多特征且特征之间关系十分复杂时,构建全局模型的想法就显得太难了。(所以,第八章的线性回归不适合)
一种可行的方法是将数据集切分成很多份易建模的数据,
然后利用第8章的线性回归技术来建模。
如果首次切分后仍然难以拟合线性模型就继续切分。在这种切分方式下,树结构和回归法就相当有用。
本章将构建两种树
第一种是回归树,其每个叶节点包含单个值
第二种是模型树,其每个叶节点包含一个线性方程
class treeNode():
def __init__(self,feat,val,right,left):
featureToSplitOn = feat
valueOfsPLIT=val
rightBranch=right
leftBranch = left
1. 读取文件
def loadDateSet(fileName):
dataMat=[]
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine=map(float,curLine) # 将每行映射成浮点数
dataMat.append(fltLine)
return dataMat
2. 辅助函数
划分数据集
# 该函数有3个参数:数据集合、待切分的特征和该特征的某个值。
# 在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回。
def binSplitDataSet(dataSet,feature,value):
mat0=dataSet[nonzero(dataSet[:,feature] > value)[0],:]
mat1=dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
return mat0,mat1
生成叶节点
它负责生成叶节点。当chooseBestSplit()函数确定不再对数据进行切分时, 将调用该regLeaf()函数来得到叶节点的模型。
在回归树中,该模型其实就是目标变量的均值。
def regLeaf(dataSet):
return mean(dataSet[:,-1])
误差估计函数
该函数在给定数据上计算目标变量的平方误差。
def regErr(dataSet):
return var(dataSet[:,-1])*shape(dataSet)[0]
3. 构造树
是回归树构建的核心函数
伪代码大致如下:
找到最佳的待切分特征:
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createTree()方法
def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
# chooseBestSplit
# 如果满足停止条件,将返回None,和某类模型的值
# 如果构建的是回归树,该模型是一个常数。如果是模型树,其模型是一个线性方程。
feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
if feat ==None:return val
retTree ={}
retTree['spInd']=feat
retTree['spVal']=val
lSet,rSet = binSplitDataSet(dataSet,feat,val)
retTree['left']=createTree(lSet,leafType,errType,ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
4. 寻找最优切分点
# 用最佳方式切分数据集和生成相应的叶节点。
对每个特征:
对每个特征值:
将数据集切分成两份
计算切分的误差
如果当前误差小于当前最小误差, 那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
# 其中变量tolS是容许的误差下降值,tolN是切分的最少样本数。
tolS=ops[0];tolN=ops[1]
# 如果特征值数目为1 , 那么就不需要再切分而直接返回
if len(set(dataSet[:,-1].T.tolist()[0]))==1:
return None,leafType(dataSet)
# 计算了当前数据集的大小和误差。该误差S将用于与新切分误差进行对比,来检查新切分能否降低误差
m,n = shape(dataSet)
S=errType(dataSet)
bestS = inf;bestIndex=0;bestValue = 0
# 如果切分数据集后效果提升不够大,那么就不应进行切分操作而直接创建叶节点。
# 检查两个切分后的子集大小,如果某个子集的大小小于用户定义的参数tolN,那么也不应切分。
# 最后,如果这些提前终止条件都不满足,那么就返回切分特征和特征值。
for featIndex in range(n-1):
for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
mat0,mat1 = binSplitDataSet(dataSet,featIndex,splitVal)
if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):continue # 样本数目太少,不切分
newS=errType(mat0)+errType(mat1)
if newS < bestS:
bestS = newS
bestIndex=featIndex
bestValue=splitVal
if(S-bestS) < tolS:
return None,leafType(dataSet) # 提升效果不够大,不切分直接创建叶节点
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex,bestValue
5. 测试
myDat = loadDateSet('ex0.txt')
myMat = mat(myDat)
tree = createTree(myMat)
print tree
运行的结果
{'spInd': 1, 'spVal': 0.39435, 'right': {'spInd': 1, 'spVal': 0.197834, 'right': -0.023838155555555553, 'left': 1.0289583666666666}, 'left': {'spInd': 1, 'spVal': 0.582002, 'right': 1.980035071428571, 'left': {'spInd': 1, 'spVal': 0.797583, 'right': 2.9836209534883724, 'left': 3.9871631999999999}}}
6. 剪枝处理
本章前面巳经进行过剪枝处理。在函数chooseBestSplit()中的提前终止条件,实际上是在进行一种所谓的预剪枝(prepruning)操作。
另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。
预剪枝:
树构建算法其实对输人的参数tolS和tolN非常敏感,如果使用其他值将不太容易达到这么好的效果。
修改一下两个值 (1,4) —> (0,1)
myDat1 = loadDateSet('ex00.txt')
myMat1 = mat(myDat1)
tree1 = createTree(myMat1)
print tree1
tree2 = createTree(myMat1,ops=(0,1))
print tree2
结果相差很大,几乎给每一个值都分配了一个节点
后剪枝:
使用后剪枝方法需要将数据集分成测试集和训练集。 首先指定参数, 使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降
低测试误差。如果是的话就合并。
伪代码:
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算将当前两个叶节点合并后的误差
计算不合并的误差
如果合并会降低误差的话,就将叶节点合并
判断当前节点是否是叶节点
def isTree(obj):
return (type(obj).__name__=='dict')
函数getMean()是一个递归函数, 它从上往下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。
def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']): tree['left'] = getMean(tree['left'])
return (tree['left']+tree['right'])/2.0
主函数
它有两个参数:待剪枝的树与剪枝所需的测试数据testData。prune() 函数首先需要确认测试集是否为空0 。一旦非空,则反复递归调用函数prune( ) 对测试数据进行切分。
def prune(tree, testData):
if shape(testData)[0] == 0: return getMean(tree)
if (isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
sum(power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print "merging"
return treeMean
else: return tree
else: return tree
7. 模型树
用树来对数据建模,除了把叶节点简单地设定为常数值之外,有一种方法是把叶节点设定为分段线性函数,这里所谓的分段线性是指模型由多个线性片段组成。
该数据实际上是由
再加上高斯噪声生成的。
需要修改的代码为:
1. 回归树的误差计算
2. 回归树节点的生成
数据的计算
def linearSolve(dataSet): #helper function used in two places
m,n = shape(dataSet)
X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
print '++++'
print X[0:3,1:n]
print '----'
xTx = X.T*X
if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverse,\n\
try increasing the second value of ops')
ws = xTx.I * (X.T * Y)
return ws,X,Y
叶节点的生成
def modelLeaf(dataSet):#create linear model and return coeficients
ws,X,Y = linearSolve(dataSet)
return ws
数据集误差计算
def modelErr(dataSet):
ws,X,Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat,2))
测试
myDat1 = loadDateSet('exp2.txt')
myMat1 = mat(myDat1)
tree1 = createTree(myMat1,modelLeaf,modelErr,(1,10))
print tree1
输出结果
{'spInd': 0, 'spVal': 0.285477, 'right': matrix([[ 3.46877936],
[ 1.18521743]]), 'left': matrix([[ 1.69855694e-03],
[ 1.19647739e+01]])}
可以看出,已经十分接近了。