Ⅲ. Implementation
The implementation uses the data sets from Machine Learning in Action; the code itself I rewrote from scratch based on my own understanding.
Data loading module: data.py
import numpy as np

def loadData(filename):
    dataSet = np.loadtxt(fname=filename, dtype=np.float32)
    return dataSet
NumPy's built-in function for reading txt files does the job, quick and convenient, so there is not much to say here.
CART core module: CART.py
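To make the expected file format concrete, here is a small self-contained sketch (the file path and sample values are made up for illustration): it writes a tiny two-column data file with `np.savetxt` and reads it back with the same `np.loadtxt` call used above.

```python
import os
import tempfile

import numpy as np

# A tiny sample: whitespace-separated numbers, one sample per row,
# with the target value in the last column.
rows = np.array([[0.1, 1.0],
                 [0.5, 1.2],
                 [0.9, 3.1]], dtype=np.float32)

path = os.path.join(tempfile.mkdtemp(), "sample.txt")
np.savetxt(path, rows)

dataSet = np.loadtxt(fname=path, dtype=np.float32)
print(dataSet.shape)  # (3, 2)
```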
import numpy as np

# split dataSet by featureIndex and threshold value
def splitDataSet(dataset, featureIndex, value):
    # rows with feature <= value go left, the rest go right
    subDataSet0 = dataset[dataset[:, featureIndex] <= value, :]
    subDataSet1 = dataset[dataset[:, featureIndex] > value, :]
    return subDataSet0, subDataSet1

# total squared error of the targets in a data set (variance * n)
def getError(dataSet):
    error = np.var(dataSet[:, -1]) * dataSet.shape[0]
    return error

# choose the best featureIndex and threshold value in dataSet
def chooseBestSplit(dataSet, leastErrorDescent, leastNumOfSplit):
    rows, cols = np.shape(dataSet)
    # error of the unsplit dataSet
    Error = getError(dataSet)
    # initialize the values we are searching for
    bestError = np.inf
    bestFeatureIndex = 0
    bestValue = 0
    # search every feature index ...
    for featureIndex in range(cols - 1):
        # ... and every distinct value of that feature
        for value in set(dataSet[:, featureIndex]):
            subDataSet0, subDataSet1 = splitDataSet(dataSet, featureIndex, value)
            # skip splits that leave too few samples on either side
            if (subDataSet0.shape[0] < leastNumOfSplit) or (subDataSet1.shape[0] < leastNumOfSplit):
                continue
            # combined error after the split
            tempError = getError(subDataSet0) + getError(subDataSet1)
            if tempError < bestError:
                bestError = tempError
                bestFeatureIndex = featureIndex
                bestValue = value
    # if the error barely decreases, stop and return a leaf value
    if Error - bestError < leastErrorDescent:
        return None, np.mean(dataSet[:, -1])
    mat0, mat1 = splitDataSet(dataSet, bestFeatureIndex, bestValue)
    if (mat0.shape[0] < leastNumOfSplit) or (mat1.shape[0] < leastNumOfSplit):
        return None, np.mean(dataSet[:, -1])
    return bestFeatureIndex, bestValue

# build the regression tree recursively
def buildTree(dataSet, leastErrorDescent=1, leastNumOfSplit=4):
    bestFeatureIndex, bestValue = chooseBestSplit(dataSet, leastErrorDescent, leastNumOfSplit)
    # recursion terminates at a leaf value
    if bestFeatureIndex is None:
        return bestValue
    Tree = {}
    Tree["featureIndex"] = bestFeatureIndex
    Tree["value"] = bestValue
    # split and recurse on both subsets
    leftSet, rightSet = splitDataSet(dataSet, bestFeatureIndex, bestValue)
    Tree["left"] = buildTree(leftSet, leastErrorDescent, leastNumOfSplit)
    Tree["right"] = buildTree(rightSet, leastErrorDescent, leastNumOfSplit)
    return Tree

def isTree(tree):
    return type(tree).__name__ == 'dict'

def predict(tree, x):
    # use <= so the decision matches the <= used in splitDataSet
    if x[tree["featureIndex"]] <= tree["value"]:
        if isTree(tree["left"]):
            return predict(tree["left"], x)
        else:
            return tree["left"]
    else:
        if isTree(tree["right"]):
            return predict(tree["right"], x)
        else:
            return tree["right"]
Let's go through these functions one by one.
splitDataSet(dataset,featureIndex,value)
As discussed in the theory section, splitting the data set only takes two values: a feature and a threshold on that feature.
This function takes those two inputs and divides the data set into two parts; the figure in the theory section's example illustrates exactly what it does.
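A quick self-contained check of the splitting logic (the toy data here is made up): rows whose feature value is at or below the threshold go to the first subset, the rest to the second.

```python
import numpy as np

def splitDataSet(dataset, featureIndex, value):
    # rows with feature <= value go left, the rest go right
    subDataSet0 = dataset[dataset[:, featureIndex] <= value, :]
    subDataSet1 = dataset[dataset[:, featureIndex] > value, :]
    return subDataSet0, subDataSet1

data = np.array([[0.2, 1.0],
                 [0.4, 1.1],
                 [0.7, 3.0],
                 [0.9, 3.2]])

left, right = splitDataSet(data, featureIndex=0, value=0.5)
print(left[:, 0])   # [0.2 0.4]
print(right[:, 0])  # [0.7 0.9]
```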
getError(dataSet)
This function computes the error; strictly speaking, it is a variance. The theory section already gave the formula, in which the constant c can be replaced by the mean, so the result is exactly the total squared error over the data set: the per-sample variance times the number of samples.
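The identity behind `getError` can be verified numerically on a small made-up target vector: the sum of squared deviations from the mean equals `np.var(y) * len(y)`.

```python
import numpy as np

y = np.array([1.0, 1.2, 3.0, 3.4])

# getError computes np.var(y) * n, which is exactly the total
# squared error around the mean -- the quantity each leaf minimizes.
sse = np.sum((y - y.mean()) ** 2)
error = np.var(y) * y.shape[0]
print(np.isclose(sse, error))  # True
```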
chooseBestSplit(dataSet,leastErrorDescent,leastNumOfSplit)
As the name suggests, it finds the best split.
The parameter leastErrorDescent is the minimum error decrease: if at some point the error drops by less than this value, the function stops and returns a leaf. leastNumOfSplit is the minimum split size: when a candidate subset would hold fewer samples than this threshold, the split is considered not worth making and is skipped.
The function then iterates over every feature in the data set, and over every value that feature takes, to find and return the best feature index and split point.
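The core of that double loop can be sketched on its own. The toy data below is made up so that the targets jump from about 1 to about 3 around x = 0.5; the brute-force search (re-defining the two helpers above, with the size/descent checks dropped for brevity) should therefore pick the last left-hand x value, 0.4, as the threshold.

```python
import numpy as np

def getError(dataSet):
    return np.var(dataSet[:, -1]) * dataSet.shape[0]

def splitDataSet(dataset, featureIndex, value):
    return (dataset[dataset[:, featureIndex] <= value, :],
            dataset[dataset[:, featureIndex] > value, :])

# Targets jump from ~1 to ~3 between x = 0.4 and x = 0.6.
data = np.array([[0.1, 1.0], [0.3, 1.1], [0.4, 0.9],
                 [0.6, 3.0], [0.8, 3.2], [0.9, 2.9]])

bestError, bestValue = np.inf, None
for value in set(data[:, 0]):           # every distinct feature value
    sub0, sub1 = splitDataSet(data, 0, value)
    if sub0.shape[0] == 0 or sub1.shape[0] == 0:
        continue                        # a split must leave both sides non-empty
    tempError = getError(sub0) + getError(sub1)
    if tempError < bestError:
        bestError, bestValue = tempError, value

print(bestValue)  # 0.4
```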
Test script: run.py
import numpy as np
import data
import CART

dataMat1 = data.loadData("../data/ex00.txt")
dataMat2 = data.loadData("../data/ex0.txt")
'''
print(dataMat1.shape)
print(np.shape(dataMat1))
e = CART.getError(dataMat1)
print(e)
mat0, mat1 = CART.splitDataSet(dataMat1, 0, 0.5)
print(CART.getError(mat0))
print(CART.getError(mat1))
print(mat0)
print(mat1)
print(mat0.shape)
'''
#bestIndex, bestValue = CART.chooseBestSplit(dataMat1, 1, 4)
#print(bestIndex, bestValue)
#tree1
tree1 = CART.buildTree(dataMat1)
print(tree1)
#tree2
tree2 = CART.buildTree(dataMat2)
print(tree2)
x = [1.0, 0.559009]
print(CART.predict(tree2, x))
This is the script used to exercise the CART regression code.
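If the book's data files are not at hand, the prediction path can still be checked against a hand-built tree (the tree below is made up, with a single split on feature 0 at 0.5): `predict` descends until it reaches a value that is not a dict, then returns it as the leaf prediction.

```python
def isTree(tree):
    return type(tree).__name__ == 'dict'

def predict(tree, x):
    # Descend left when the feature value is at or below the threshold,
    # mirroring the <= used by splitDataSet.
    if x[tree["featureIndex"]] <= tree["value"]:
        branch = tree["left"]
    else:
        branch = tree["right"]
    return predict(branch, x) if isTree(branch) else branch

# A hand-built two-leaf tree: split on feature 0 at 0.5.
tree = {"featureIndex": 0, "value": 0.5, "left": 1.0, "right": 3.0}
print(predict(tree, [0.3]))  # 1.0
print(predict(tree, [0.7]))  # 3.0
```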
Result:
---------------------
Author: 谢小小XH
Source: CSDN
Original: https://blog.csdn.net/xierhacker/article/details/64439601
Copyright notice: this is an original post by the author; please include a link to the original when reposting.