机器学习笔记十二:分类与回归树CART

本文详细介绍了如何使用CART算法构建回归树,并提供了完整的代码实现。从数据读取、划分数据集、计算误差到选择最佳分裂点,直至构建决策树,每一步都进行了深入解析。此外,还展示了如何进行预测及测试代码的运行结果。
摘要由CSDN通过智能技术生成

Ⅲ.实现
实现部分采用的数据集是机器学习实战中的数据集.代码则是按照自己的理解重新改写了一遍.

读取数据模块:data.py

import numpy as np
def loadData(filename):
    dataSet=np.loadtxt(fname=filename,dtype=np.float32)
    return dataSet
1
2
3
4
用numpy内置的读取txt文件的函数就行,方便快捷.这里就不多讲了.

CART核心模块:CART.py

import numpy as np
import matplotlib.pyplot as plt

#split dataSet trough featureIndex and value
def splitDataSet(dataset,featureIndex,value):
    subDataSet0=dataset[dataset[:,featureIndex]<=value,:]
    subDataSet1=dataset[dataset[:,featureIndex]>value,:]
    return subDataSet0,subDataSet1

#compute the regression Error in a data Set
def getError(dataSet):
    error=np.var(dataSet[:,-1])*dataSet.shape[0]
    return error

#choose the best featureIndex and value in dataSet
def chooseBestSplit(dataSet,leastErrorDescent,leastNumOfSplit):
    rows,cols=np.shape(dataSet)

    #error in dataSet
    Error=getError(dataSet)

    #init some important value we want get
    bestError=np.inf
    bestFeatureIndex=0
    bestValue=0

    #search process
    #every feature index
    for featureIndex in range(cols-1):
        #every value in dataSet of specific index
        for value in set(dataSet[:,featureIndex]):
            subDataSet0,subDataSet1=splitDataSet(dataSet,featureIndex,value)
            #print("sub0",subDataSet0.shape[0])
            #print("sub1", subDataSet1.shape[0])

          #  print(subDataSet0)
            if (subDataSet0.shape[0]<leastNumOfSplit) or (subDataSet1.shape[0]<leastNumOfSplit):
                continue
            #compute error
            tempError=getError(subDataSet0)+getError(subDataSet1)
            #print("tempError:",tempError)
            if tempError<bestError:
                bestError=tempError
                bestFeatureIndex=featureIndex
                bestValue=value

           # print("BestError:", bestError)
           # print("BestIndex:", bestFeatureIndex)
           # print("BestValue:", bestValue)
    if Error-bestError<leastErrorDescent:
        return None,np.mean(dataSet[:,-1])
    mat0,mat1=splitDataSet(dataSet,bestFeatureIndex,bestValue)
    if (mat0.shape[0]<leastNumOfSplit) or (mat1.shape[0]<leastNumOfSplit):
        return None,np.mean(dataSet[:,-1])

    return bestFeatureIndex,bestValue


#build tree
def buildTree(dataSet,leastErrorDescent=1,leastNumOfSplit=4):
    bestFeatureIndex,bestValue=chooseBestSplit(dataSet,leastErrorDescent,leastNumOfSplit)

    #recursion termination
    if bestFeatureIndex==None:
        return bestValue

    Tree={}
    Tree["featureIndex"]=bestFeatureIndex
    Tree["value"]=bestValue
    #get subset
    leftSet,rightSet=splitDataSet(dataSet,bestFeatureIndex,bestValue)

    #recursive function
    Tree["left"]=buildTree(leftSet,leastErrorDescent,leastNumOfSplit)
    Tree["right"] = buildTree(rightSet, leastErrorDescent, leastNumOfSplit)

    return Tree

def isTree(tree):
    return (type(tree).__name__=='dict')


def predict(tree,x):
    if x[tree["featureIndex"]]<tree["value"]:
        if isTree(tree["left"]):
            return predict(tree["left"],x)
        else:
            return tree["left"]

    else:
        if isTree(tree["right"]):
            return predict(tree["right"],x)
        else:
            return tree["right"]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
这里一个一个来讲这些函数. 
splitDataSet(dataset,featureIndex,value) 
在理论部分已经讲到,我们要划分数据集,只需要两个值,一个就是特征,另外就是指定的阈值. 
这个函数的作用就是通过传入的特征和阈值,把数据集划分为两部分.理论部分例子的图就可以形象展示这个函数的作用.

getError(dataSet) 
这个函数是用来得到误差的.说是误差,倒不如说是方差.因为理论部分已经给出了式子,其中的c是可以用平均值来替代的,也就是是,刚好是数据集上面的总的方差.

chooseBestSplit(dataSet,leastErrorDescent,leastNumOfSplit) 
顾名思义,就是找最好的划分罗. 
leastErrorDescent这个参数表示最小的下降误差,也就是说要是在某一刻,误差的下降小于这个值,函数就会退出,leastNumOfSplit表示最小的划分数量.当要划分的集合元素小于这个阈值时候,被认为是没有什么划分的意义了,函数也不会再运行. 
然后函数遍历数据集上面所有的特征,与特征上面的所有值,以找到最好的特征索引和划分点返回.

测试文件:run.py

import numpy as np
import data
import CART

dataMat1=data.loadData("../data/ex00.txt")
dataMat2=data.loadData("../data/ex0.txt")

'''
print(dataMat.shape)
print(np.shape(dataMat))
e=CART.getError(dataMat)
print(e)
print(CART.getError(mat0))
print(CART.getError(mat1))

mat0,mat1=CART.splitDataSet(dataMat,0,0.5)
print(mat0)
print(mat1)
print(mat0.shape)
'''

#bestIndex,bestValue=CART.chooseBestSplit(dataMat)
#print(bestIndex,bestValue)

#tree1
tree1=CART.buildTree(dataMat1)
print(tree1)

#tree2
tree2=CART.buildTree(dataMat2)
print(tree2)

x=[1.0,0.559009]
print(CART.predict(tree2,x))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
用来测试CART回归的运行代码.

import numpy as np
import data
import CART

dataMat1=data.loadData("../data/ex00.txt")
dataMat2=data.loadData("../data/ex0.txt")

'''
print(dataMat.shape)
print(np.shape(dataMat))
e=CART.getError(dataMat)
print(e)
print(CART.getError(mat0))
print(CART.getError(mat1))

mat0,mat1=CART.splitDataSet(dataMat,0,0.5)
print(mat0)
print(mat1)
print(mat0.shape)
'''

#bestIndex,bestValue=CART.chooseBestSplit(dataMat)
#print(bestIndex,bestValue)

#tree1
tree1=CART.buildTree(dataMat1)
print(tree1)

#tree2
tree2=CART.buildTree(dataMat2)
print(tree2)

x=[1.0,0.559009]
print(CART.predict(tree2,x))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
结果: 

--------------------- 
作者:谢小小XH 
来源:CSDN 
原文:https://blog.csdn.net/xierhacker/article/details/64439601 
版权声明:本文为博主原创文章,转载请附上博文链接!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值