机器学习实战(九)


title: 机器学习实战(九)
date: 2020-05-01 09:20:50
tags: [树回归, CSRT算法, 树剪枝算法]
categories: 机器学习实战
更多内容请关注我的博客

数回归

分类回归树 Classification And Regression Trees 分类回归树。该算法既可以用于回归还可以用于分类。

复杂数据的局部性建模

数回归

优点:可以对复杂和线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型数据

第三章使用的树构建的算法是ID3。ID3的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来切分。也就取值,那么数据将被切分成4份,一但按某种特征切分后,该特征在之后的算法执行过程中将不会再起作用,所以有观点认为这种切分方式过于迅速。另一种方法是二元切分发,即每次吧数据集切分成两份。如果数据的某个特征等于切分所要求的值,那么这些数据就进入树的左子树,反之则进入树的右子树。

除了切分过于迅速外,ID3算法还存在另一个问题,它不能直接处理连续型特征。只有事先将连续型特征转换成离散型,才能在ID3算法中使用。但这种转换过程会破坏连续型变量的内在性质。而使用二元切分法则对于树构建过程进行调整以处理连续型特征。

具体处理方法是:

如果特征值大于给定值就走左子树,否则就走右子树。

另外,二元切分法也节省了树的构建时间,但这点意义也不是特别大因为这些树的构建一般是离线完成的。

CART是十分著名且广泛记载的树构建算法,它使用二元切分来处理连续型变量。对CART稍作修改就可以处理回归问题。

回归树的一般方法:

  1. 收集数据
  2. 准备数据:需要数值型的数据,标称型数据应该映射成二值型数据
  3. 分析数据:绘出数据的二维可视化显示结果,以字典方式生成树
  4. 训练算法:大部分时间都花费在叶节点树模型的构建上
  5. 测试算法:使用测试数据上的R^2值来分析模型的效果
  6. 使用算法:使用训练出的树做预测

连续和离散型特征的树的构建

在树的构建过程中,需要解决多种类型数据的存储问题。这里将使用字典来存储树的数据结构,该字典将包含以下四种元素。

待切分的特征
待切分的特征值
右子树。当不需要切分时,也可以是单值
左子树。与右子树类似

CART算法只做二元切分,所以这里可以固定树的数据结构。树包含左键和右键,可以存储另一颗树或者单个值。字典还包含特征和特征值这两个键,它们给出的切分算法所有的特征和特征值。

接下来构建两种树,第一种是回归树(regression tree)其中每个叶节点包含单个值,第二种是是模型树(model tree)其中每个叶节点包含一个线性方程。

createTree()的伪代码大致如下:

找到最佳的待切分特征:
    如果该节点不能再分,将该节点存为叶节点
    执行二元切分
    在右子树调用createTree()方法
    在左子树调用createTree()方法
from numpy import *
def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float, curLine)
        fltLine = list(fltLine)
        dataMat.append(fltLine)
    fr.close()
    return dataMat

def regLeaf(dataSet): # returns the value used for each leaf
    return mean(dataSet[:,-1])

def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # 满足条件时返回叶节点值
    if feat == None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
testMat = mat(eye(4))
testMat
matrix([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)
mat0
matrix([[0., 1., 0., 0.]])
mat1
matrix([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
nonzero(testMat[:, 1] > 0.5)[0][0]
1

将CART算法用于回归

要对数据的复杂关系建模,我们已经决定借用树结构来帮助切分数据,那么如何实现数据的切分呢?

为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。在给定的节点计算数据的混乱度,计算数据混乱度的方法,首先计算所有数据的均值,然后计算每条数据的值到均值的差值,为了对正负差值同等看待,一般使用绝对值或平方值来代替上述差值。类似方差的计算,唯一不同是方差是平方误差的均值,而这里需要的是平方误差的总值,总方差可以通过均方差乘以数据集中样本点的个数来得到。

构建树

构建回归树,需要补充一些新的代码,首先要做的就是实现chooseBestSplit()函数,给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。另外该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。因此chooseBestSplit()函数需要完成两件事:用最佳方式切分数据集和生成相应的叶节点。

leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,而ops是一个用户定义的参数构成的元组,用已完成树的构建。

伪代码如下:

对每个特征:
    对每个特征值:
        将数据集切分成两份
        计算切分的误差
        如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]
    tolN = ops[1]

    if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 如果所有值相等则退出
        return None, leafType(dataSet)
    
    m, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:, featIndex].tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN) :
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS

    if (S - bestS) < tolS: # 如果误差减小不大则退出
        return None, leafType(dataSet)
    
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 如果切分出的数据集很小则退出
        return None, leafType(dataSet)
    
    return bestIndex, bestValue
    

regLeaf()它负责生成叶节点。当chooseBestSplit()函数确定不再对数据进行切分时,将调用该regLeaf()函数来得到叶节点的模型,在回归树中,该模型就是目标变量的均值。

regErr()是误差估计函数,该函数在给定数据上计算目标变量的平方误差,当然也可以先计算出均值,然后计算每个差值再平方。因为这里需要总方差,所以用均方差函数var()的结果乘以数据集中的样本个数。

chooseBestSplit()该函数的目的是找到数据的最佳二元切分方式。如果找不到一个好的二元切分,该函数返回None并同时调用createTree()方法来产生叶节点,叶节点的值也将返回None。ops设定了tolS和tolN两个值,tolS是容许的误差下降值,tolN是切分的最少样本数。

运行代码

myDat = loadDataSet('MLiA_SourceCode/Ch09/ex00.txt')
myMat = mat(myDat)
createTree(myMat)
{'spInd': 0,
 'spVal': 0.036098,
 'left': 0.5878577680412371,
 'right': 0.050698999999999994}
import matplotlib.pyplot as plt
def plotScatter(data):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    #print(myMat)
    ax.scatter(data[:, 0].T.tolist()[0], data[:, 1].T.tolist()[0], 5, c='red')
    plt.show()
plotScatter(myMat)

在这里插入图片描述

myDat = loadDataSet('MLiA_SourceCode/Ch09/ex0.txt')
myMat = mat(myDat)
createTree(myMat)
{'spInd': 1,
 'spVal': 0.409175,
 'left': {'spInd': 1,
  'spVal': 0.663687,
  'left': {'spInd': 1,
   'spVal': 0.725426,
   'left': 3.7206952592592595,
   'right': 2.998615611111111},
  'right': 2.2076016800000002},
 'right': 0.45470547435897446}
plotScatter(myMat[:, 1:])

在这里插入图片描述

树剪枝

通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。

预剪枝

树构建算法其实对输入的参数tolS和tolN非常敏感,对ops参数调整就是预剪枝。

后剪枝

利用测试集来对树进行剪枝,不需要用户指定参数,为后剪枝。

prune()伪代码如下

基于已有的树切分测试数据:
    如果存在任一子集是一棵树,则在该子集递归剪枝过程
    计算将当前两个叶节点合并后的误差
    计算不合并的误差
    如果合并降低误差,就将叶节点合并
def isTree(obj):
    return (type(obj).__name__=='dict')

def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0
    
def prune(tree, testData):
    if shape(testData)[0] == 0:
        return getMean(tree)  # if we have no test data collapse the tree
    if (isTree(tree['right']) or isTree(tree['left'])): # if the branches are not trees try to prune them
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)
    # if they are now both leafs, see if we can merge them
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
            sum(power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge: 
            return treeMean
        else: 
            return tree
    else: 
        return tree
myDat2 = loadDataSet('MLiA_SourceCode/Ch09/ex2.txt')
myMat2 = mat(myDat2)
myTree = createTree(myMat2, ops=(0, 1))
myDatTest = loadDataSet('MLiA_SourceCode/Ch09/ex2test.txt')
myMat2Test = mat(myDatTest)
prune(myTree, myMat2Test)
{'spInd': 0,
 'spVal': 0.228628,
 'left': {'spInd': 0,
  'spVal': 0.965969,
  'left': 92.5239915,
  'right': 65.53919801898735},
 'right': -1.1055498250000002}

模型树

用树来对数据建模,除了吧叶节点简单地设定为常数值之外,还有一种方法是把叶节点设定为分段线性函数,这里所谓的分段性(piecewise linear)是指模型由多个线性片段组成。

决策树相比其他机器学习算法的优势之一在于结果更易理解。很显然,两条直线比很多节点组成一棵大树更容易解释。模型树的可解释性是它优于回归树的特点之一。另外,模型树也具有更高的预测准确度。

def linearSolve(dataSet):
    m, n = shape(dataSet)
    X = mat(ones((m, n)))
    Y = mat(ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]
    Y = dataSet[:, -1]
    xTx = X.T*X
    
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\
                         try increasing the second value of ops')
    
    ws = xTx.I * (X.T * Y)
    return ws, X, Y

def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))
createTree(myMat2, modelLeaf, modelErr, (1, 10))
{'spInd': 0, 'spVal': 0.228628, 'left': None, 'right': None}

总结

这一章提供的code有很多错误,修正后并不能得到书中的答案。如果要使用树算法,还是建议使用sklearn,而非自己编写。

CART算法可以用于构建二元树并处理离散型或连续型数据的切分。若使用不同的误差准则,就可以通过CART算法构建模型树和回归树。该算法构建出的树会倾向于对数据过拟合。过拟合的树十分复杂,剪枝可以解决这个问题。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值