回归树实现

回归树实现

之前实现过决策树,回归树与决策树的原理类似,可以说他们是双胞胎,决策树用于分类标量型数据,回归树用于回归连续性数据。

不同于决策树,回归树没有信息增益原理,但是也考虑样本数据的纯度,连续型的数据则用方差表示纯度,决策树在以一个属性作为分支后,则不会再考虑该属性,而回归树并不是,回归树以一个属性的一个值作为分支,小于该值的样本和大于等于该值的样本为分支,并且要满足分支后的样本方差和要小于原来的方差。

最终停止分支的条件可以为

  • 分支后的样本数过小
  • 分支后的方差减少过小

代码(有批注):

'''
Created on 2.3, 2020
Decision Tree Source Code 
@author: Yliemevoli
'''
from numpy import *
import matplotlib.pyplot as plt

#数据集的构成是[属性,目标变量]
def loadDataSet(filename):
    dataMat = []
    fr = open(filename)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float,curLine)   #返回迭代器将curLine变成float
        dataMat.append(list(fltLine))   #变成list返回
    return dataMat

def binSplitData(dataSet,feature,value):
    #这个函数将dataSet返回两个子集,在feature的属性上小于value和大于value分成两个子集
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0,mat1


def regLeaf(dataSet):
    return mean(dataSet[:,-1])
def regErr(dataSet):
    return var(dataSet[:,-1])*shape(dataSet)[0]

def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(0,4)):
    tolS = ops[0];tolN = ops[1] #前者是可以分支的最小方差差,后者是分支后的最小样本数
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None,leafType(dataSet)#只有一个样本可以结束了
    m,n = shape(dataSet)#m=样本个数,n=样本属性个数
    S = errType(dataSet)#整个节点的方差
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):#简单来说就是遍历所有属性的所有可能取值
        data = dataSet[:,featIndex].T.tolist()#取出该属性的所有值
        for splitVal in set(data[0]):#不必计算重复的值
            mat0,mat1 = binSplitData(dataSet,featIndex,splitVal)#以该属性的一个值作为分支
            if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):continue
            newS = errType(mat0)+errType(mat1) #计算分支后的方差和
            if newS < bestS:#筛选出最小
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S-bestS) < tolS:
        return None,leafType(dataSet)
    mat0,mat1 = binSplitData(dataSet,bestIndex,bestValue)
    if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None,leafType(dataSet)
    return bestIndex,bestValue

def creatTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
	dataSet = mat(dataSet)
    feat,val = chooseBestSplit(dataSet,leafType,errType,ops)#当前最佳的分支属性
    if feat == None:return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet,rSet = binSplitData(dataSet,feat,val)
    retTree['left'] = creatTree(lSet,leafType,errType,ops)#递归生成树
    retTree['right'] = creatTree(rSet,leafType,errType,ops)
    return retTree

def isTree(obj):
    return (type(obj).__name__ == 'dict')

def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

def prune(tree, testData):#后剪枝操作
	testData = mat(testData)
    lSet,rSet = binSplitData(testData,tree['spInd'],tree['spVal'])
    if isTree(tree['left']):
        prune(tree['left'],lSet) #如果左儿子是树,则递归下去
    if isTree(tree['right']):
        prune(tree['right'],rSet)   #同理
    if not isTree(tree['left']) and not isTree(tree['right']): #左右儿子是叶子了
        errorNoMerge = sum(power(tree['left']-lSet[:,-1],2)) +\
             sum(power(tree['right']-rSet[:,-1],2))#不合并的方差,tree[left]就是叶子的均值
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(power(testData[:,-1]-treeMean,2))#合并后的方差
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean #如果可以合并,则返回合并后的节点
        else: return tree   #不可以合并,返回原来的树
    else: return tree

参考《机器学习实战》第9章回归树

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值