CART 分类回归树、模型树, 及REP后剪枝

本文介绍了CART分类回归树及其在数据处理中的应用,通过实例展示了模型树的构建过程,并详细解析了REP后剪枝的判断准则。内容包括运行结果分析、样本数据散点图展示及剪枝条件解释,强调了剪枝时方差比较的重要性。
摘要由CSDN通过智能技术生成
#coding=utf-8
from numpy import *
def load_data(file_name):
    data_mat=[]
    fr=open(file_name)
    for line in fr.readlines():
        curline=line.strip().split('\t')
        fltline=map(float,curline)
        data_mat.append(fltline)
    return data_mat
def regleaf(data):
    '''平均数'''
    return mean(data[:,-1])
def regerr(data):
    '''方差'''
    return var(data[:,-1])*shape(data)[0]
def split_data_set(data,feature,value):
    '''以样本中某一值分类,大于这个值的为一类,小于等于的为另一类'''
    mat0=data[nonzero(data[:,feature]>value)[0],:]
    #nonzero返回非零元素行列中坐标
    #data[nonzero]可根据坐标返回矩阵中的非零元素
    mat1=data[nonzero(data[:,feature]<=value)[0],:]
    return mat0,mat1
def choose_best_split(data,leaftype=regleaf,errtype=regerr,ops=(1,4)):#ops是自定义值,用于控制函数停止的时机。ops[0]定义了总体方差与平均方差的最小值,ops[1]集合长度。
    '''选择最佳分类'''
    tols=ops[0];toln=ops[1]
    if len(set(data[:,-1].T.tolist()[0]))==1:
        #数据最后一列转成列表,并判断是否只有一个元素
        #如果是就返回None和列表的平均数
        return None,leaftype(data)
    m,n=shape(data)
    s=errtype(data)
    #数据最后一列的方差*行数=总体方差
    min_var=inf;best_index=0;best_value=0
    #min_var=正无穷,初始为正无穷是因为要让第一次判断无论如何都成立,这样才能继续下去
    for feat_index in range(n-1):
        #range(n-1)=0,样本只有两列
        for split_value in set([float(i) for i in data[:,feat_index]]):
            #分类值取自第一列数据的所组成的集合,集合具有互异性
            mat0,mat1=split_data_set(data,feat_index,split_value)
            #
            if (shape(mat0)[0]<toln) or shape(mat1)[0]<toln:continue
            #如果mat0的行数小于4或者mat1的行数小于4,则结束这次循环,下面代码不再执行
            #
            news=errtype(mat0)+errtype(mat1)
            #俩子集方差相加
            if news<min_var:
                #取最小方差,也就是波动最小,相似度较高
                best_index=feat_index
                best_value=split_value
                min_var=news
    if (s-min_var)<tols:
        #总体方差-最小方差
        return None,leaftype(data)
    mat0,mat1=split_data_set(data,best_index,best_value)
    if (shape(mat0)[0]<toln) or (shape(mat1)[0]<toln):
        #如果mat0的行数小于4或者mat1的行数小于4
        return None,leaftype(data)
    return best_index,best_value
    
    
def create_tree(data,leaftype=regleaf,errtype=regerr,ops=(1,4)):
    '''建立回归树'''
    feat,val=choose_best_split(data,leaftype,errtype,ops)
    #分类特征,分类值
    if feat == None:return val
    #如果分类特征为空,则返回分类值。应该理解为样本数据根据单一值就可以很好的划分。
    tree={}
    tree['split_index']=feat
    tree['split_value']=val
    left_set,right_set=split_data_set(data,feat,val)
    #根据最佳分类获得左右两个集合
    #大于的在左,小的在右。。这个可以自己改
    tree['left']=create_tree(left_set,leaftype,errtype,ops)
    tree['right']=create_tree(right_set,leaftype,errtype,ops)
    return tree


def test():
    data=load_data('ex00.txt')
    data=mat(data)
    t=create_tree(data)
    print t
    import matplotlib.pyplot as plt
    fig=plt.figure()
    ax=fig.add_subplot(111)
    x=[float(i) for i in data[:,0]]
    y=[float(i) for i in data[:,1]]
    ax.scatter(x,y)   
    plt.show()
test()

运行结果:{'split_index': 0, 'right': -0.044650285714285719, 'split_value': 0.48813, 'left': 1.0180967672413792}

样本数据散点图:




REP后剪枝:

def is_tree(obj):
    '''判断输入的是否是树'''
    return (type(obj).__name__=='dict')
    #返回布尔型
def get_mean(tree):
    '''求树的平均数'''
    if is_tree(tree['right']):tree['right']=get_mean(tree['right'])
    #tree的右边为树时,就返回树
    if is_tree(tree['left']):tree['left']=get_mean(tree['left'])
    #同上
    return (tree['left']+tree['right'])/2.0

def prune(tree,testdata):
    '''后剪枝'''
    if shape(testdata)[0]==0:return get_mean(tree)
    #判断数据是否为0行,即是否为空
    if (is_tree(tree['right'])or is_tree(tree['left'])):
        #左右有一个为子树时,就根据树的最佳分类特征和分类值,把这个子树分下去
        l_set,r_set=split_data_set(testdata,tree['split_index'],tree['split_value'])
        #根据上个数据的最佳值划分子树
    if is_tree(tree['left']):tree['left']=prune(tree['left'],l_set)
    #如果左集合是子树,就给子树剪枝。结合上面就是说这个函数就是一个递归
    if is_tree(tree['right']):tree['right']=prune(tree['right'],r_set)
    #同上
    if not is_tree(tree['left']) and not is_tree(tree['right']):
        #如果左右两个子集为空
        l_set,r_set=split_data_set(testdata,tree['split_index'],tree['split_value'])
        #根据上个数据的最佳值划分子树
        error_no_merge=sum(power(l_set[:,-1]-tree['left'],2))+sum(power(r_set[:,-1]-tree['right'],2))
        #power函数,power(2,2)=4 即 power(a,b) ==> a**b
        #左方差加右方差
        tree_mean=(tree['left']+tree['right'])/2.0
        #左右子树的平均
        error_merge=sum(power(testdata[:,-1]-tree_mean,2))
        #数据最后一列的方差
        if error_merge<error_no_merge:
            #若总方差小于左右两侧相加的方差,则合并子树
            #这个判断是剪枝的关键
            print 'merging'
            return tree_mean
        else:
            return tree
    else:
        return tree
def test():
    data=load_data('ex0.txt')
    data=mat(data)
    tree=create_tree(data,ops=(0,1))
    print tree
    testdata=load_data('ex2test.txt')
    testdata=mat(testdata)
    t=prune(tree,testdata)
    print '\n'*2
    print t
 REP剪枝需要用新的数据集,原因是如果用旧的数据集,不可能出现分裂后的错误率比分裂前错误率要高的情况。由于使用新的数据集没有参与决策树的构建,能够降低训练数据的影响,降低过拟合的程度,提高预测的准确率。

关键判断:

   if error_merge<error_no_merge:
            #若总方差小于左右两侧相加的方差,则合并子树
            #这个判断是剪枝的关键
            print 'merging'
这里意思就是说,如果下面的子树左右两侧的方差相加大于以此为节点的方差,则合并子树。


盗图:http://www.cnblogs.com/yonghao/p/5064996.html

就是说2,4,5的方差和大于x的方差,则剪枝。

或者4,5 的方差大于3 的,就可把4,5剪掉。




模型树:

用分段直线划分数据。

def linear_solve(data):
    '''求解线性回归系数'''
    m,n=shape(data)
    x=mat(ones((m,n)));y=mat(ones((m,1)))
    x[:,1:n]=data[:,0:n-1];y=data[:,-1]
    #让x的第二列等于data的第一列,y等于data的最后一列
    xtx=x.T*x
    if linalg.det(xtx)==0.0:
        #判断是否为奇异矩阵
        raise NameError('This matrix is singular,cannot do inverse,\n try increasing the second value of ops')
    ws=xtx.I*(x.T*y)
    #求得回归系数
    return ws,x,y
def model_leaf(data):
    #叶节点
    ws,x,y=linear_solve(data)
    return ws
def model_err(data):
    #错误值
    ws,x,y=linear_solve(data)
    yhat=x*ws
    #预测值
    return sum(power((y-yhat),2))
def test():
    data=load_data('exp2.txt')
    data=mat(data)
    tree=create_tree(data,model_leaf,model_err,ops=(1,10))
    print tree
    #testdata=load_data('ex2test.txt')
    #testdata=mat(testdata)
    #t=prune(tree,testdata)
    #print t
    import matplotlib.pyplot as plt
    fig=plt.figure()
    ax=fig.add_subplot(111)
    x=[float(i) for i in data[:,0]]
    y=[float(i) for i in data[:,1]]
    ax.scatter(x,y)   
    plt.show()
test()


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值