机器学习实战_09树回归

当数据拥有众多特征且特征之间关系十分复杂时,构建全局模型的想法就显得太难了。(所以,第八章的线性回归不适合)
一种可行的方法是将数据集切分成很多份易建模的数据,
然后利用第8章的线性回归技术来建模。
如果首次切分后仍然难以拟合线性模型就继续切分。在这种切分方式下,树结构和回归法就相当有用。

本章将构建两种树
第一种是回归树,其每个叶节点包含单个值
第二种是模型树,其每个叶节点包含一个线性方程

class treeNode():
    def __init__(self,feat,val,right,left):
        featureToSplitOn = feat
        valueOfsPLIT=val
        rightBranch=right
        leftBranch = left

1. 读取文件

def loadDateSet(fileName):
    dataMat=[]
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine=map(float,curLine)  # 将每行映射成浮点数
        dataMat.append(fltLine)
    return dataMat

2. 辅助函数

划分数据集

# 该函数有3个参数:数据集合、待切分的特征和该特征的某个值。
# 在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回。
def binSplitDataSet(dataSet,feature,value):
    mat0=dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1=dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0,mat1

生成叶节点

它负责生成叶节点。当chooseBestSplit()函数确定不再对数据进行切分时, 将调用该regLeaf()函数来得到叶节点的模型。
在回归树中,该模型其实就是目标变量的均值。

def regLeaf(dataSet):
    return mean(dataSet[:,-1])

误差估计函数

该函数在给定数据上计算目标变量的平方误差。

def regErr(dataSet):
    return var(dataSet[:,-1])*shape(dataSet)[0]

3. 构造树

是回归树构建的核心函数
伪代码大致如下:

找到最佳的待切分特征:
    如果该节点不能再分,将该节点存为叶节点
    执行二元切分
    在右子树调用createTree()方法
    在左子树调用createTree()方法
def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
    # chooseBestSplit
    # 如果满足停止条件,将返回None,和某类模型的值
    # 如果构建的是回归树,该模型是一个常数。如果是模型树,其模型是一个线性方程。
    feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
    if feat ==None:return val
    retTree ={}
    retTree['spInd']=feat
    retTree['spVal']=val
    lSet,rSet = binSplitDataSet(dataSet,feat,val)
    retTree['left']=createTree(lSet,leafType,errType,ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

4. 寻找最优切分点

# 用最佳方式切分数据集和生成相应的叶节点。
对每个特征:
    对每个特征值:
        将数据集切分成两份
        计算切分的误差
        如果当前误差小于当前最小误差, 那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
def  chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
# 其中变量tolS是容许的误差下降值,tolN是切分的最少样本数。
    tolS=ops[0];tolN=ops[1]
# 如果特征值数目为1 , 那么就不需要再切分而直接返回
    if len(set(dataSet[:,-1].T.tolist()[0]))==1:
        return None,leafType(dataSet)
# 计算了当前数据集的大小和误差。该误差S将用于与新切分误差进行对比,来检查新切分能否降低误差
    m,n = shape(dataSet)
    S=errType(dataSet)
    bestS = inf;bestIndex=0;bestValue = 0

# 如果切分数据集后效果提升不够大,那么就不应进行切分操作而直接创建叶节点。
# 检查两个切分后的子集大小,如果某个子集的大小小于用户定义的参数tolN,那么也不应切分。
# 最后,如果这些提前终止条件都不满足,那么就返回切分特征和特征值。

    for featIndex in range(n-1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
            mat0,mat1 = binSplitDataSet(dataSet,featIndex,splitVal)
            if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):continue # 样本数目太少,不切分
            newS=errType(mat0)+errType(mat1)
            if newS < bestS:
                bestS = newS
                bestIndex=featIndex
                bestValue=splitVal
    if(S-bestS) < tolS:
        return None,leafType(dataSet)  # 提升效果不够大,不切分直接创建叶节点
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex,bestValue

5. 测试

myDat = loadDateSet('ex0.txt')
myMat = mat(myDat)
tree = createTree(myMat)
print tree

这里写图片描述
运行的结果

{'spInd': 1, 'spVal': 0.39435, 'right': {'spInd': 1, 'spVal': 0.197834, 'right': -0.023838155555555553, 'left': 1.0289583666666666}, 'left': {'spInd': 1, 'spVal': 0.582002, 'right': 1.980035071428571, 'left': {'spInd': 1, 'spVal': 0.797583, 'right': 2.9836209534883724, 'left': 3.9871631999999999}}}

这里写图片描述

6. 剪枝处理

本章前面巳经进行过剪枝处理。在函数chooseBestSplit()中的提前终止条件,实际上是在进行一种所谓的预剪枝(prepruning)操作。
另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。

预剪枝:
树构建算法其实对输人的参数tolS和tolN非常敏感,如果使用其他值将不太容易达到这么好的效果。
修改一下两个值 (1,4) —> (0,1)

myDat1 = loadDateSet('ex00.txt')
myMat1 = mat(myDat1)
tree1 = createTree(myMat1)
print tree1

tree2 = createTree(myMat1,ops=(0,1))
print tree2

结果相差很大,几乎给每一个值都分配了一个节点

这里写图片描述

后剪枝:
使用后剪枝方法需要将数据集分成测试集和训练集。 首先指定参数, 使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降
低测试误差。如果是的话就合并。
伪代码:

基于已有的树切分测试数据:
    如果存在任一子集是一棵树,则在该子集递归剪枝过程
    计算将当前两个叶节点合并后的误差
    计算不合并的误差
    如果合并会降低误差的话,就将叶节点合并

判断当前节点是否是叶节点

def isTree(obj):
    return (type(obj).__name__=='dict')

函数getMean()是一个递归函数, 它从上往下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。

def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

主函数
它有两个参数:待剪枝的树与剪枝所需的测试数据testData。prune() 函数首先需要确认测试集是否为空0 。一旦非空,则反复递归调用函数prune( ) 对测试数据进行切分。

def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree) 
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)

    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
            sum(power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge: 
            print "merging"
            return treeMean
        else: return tree
    else: return tree

7. 模型树
用树来对数据建模,除了把叶节点简单地设定为常数值之外,有一种方法是把叶节点设定为分段线性函数,这里所谓的分段线性是指模型由多个线性片段组成。
这里写图片描述
该数据实际上是由
这里写图片描述
再加上高斯噪声生成的。

需要修改的代码为:
1. 回归树的误差计算
2. 回归树节点的生成

数据的计算

def linearSolve(dataSet):   #helper function used in two places
    m,n = shape(dataSet)
    X = mat(ones((m,n))); Y = mat(ones((m,1)))#create a copy of data with 1 in 0th postion
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#and strip out Y
    print '++++'
    print X[0:3,1:n]
    print '----'
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y

叶节点的生成

def modelLeaf(dataSet):#create linear model and return coeficients
    ws,X,Y = linearSolve(dataSet)
    return ws

数据集误差计算

def modelErr(dataSet):
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat,2))

测试

myDat1 = loadDateSet('exp2.txt')
myMat1 = mat(myDat1)
tree1 = createTree(myMat1,modelLeaf,modelErr,(1,10))
print tree1

输出结果

{'spInd': 0, 'spVal': 0.285477, 'right': matrix([[ 3.46877936],
        [ 1.18521743]]), 'left': matrix([[  1.69855694e-03],
        [  1.19647739e+01]])}

这里写图片描述
可以看出,已经十分接近了。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值