1. CART regression tree code
from numpy import *

# Load the dataset: one tab-separated sample per line, the last column is the target value
def load_data(fileName):
    dataset = []
    fr = open(fileName).readlines()
    for line in fr:
        curline = line.strip().split('\t')
        fltline = list(map(float, curline))  # wrap map() in list(); in Python 3 map returns a lazy iterator
        dataset.append(fltline)
    dataset = array(dataset)  # convert to a numpy array for easy computation
    return dataset
# Generate a leaf node: under the least-squares criterion, the leaf value is the mean target of its region
def regLeaf(dataset):
    return mean(dataset[:, -1])
# Total squared error of a region = variance x sample count = sum of squared deviations from the region mean
def regErr(dataset):
    return var(dataset[:, -1]) * shape(dataset)[0]
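# Sanity check (hypothetical numbers): for targets [1.0, 2.0, 3.0] the mean is 2.0,
# var = 2/3, and regErr returns 2/3 * 3 = 2.0 -- exactly (1-2)^2 + (2-2)^2 + (3-2)^2.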
# Split the dataset into two parts on a given feature index and split value
def splitDataset(dataset, index_of_feature, value):
    mat0 = dataset[nonzero(dataset[:, index_of_feature] <= value)[0], :]
    mat1 = dataset[nonzero(dataset[:, index_of_feature] > value)[0], :]
    return mat0, mat1
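# Example (hypothetical data): for dataset = array([[0.1, 1.0], [0.5, 2.0], [0.9, 3.0]]),
# splitDataset(dataset, 0, 0.5) returns the first two rows as mat0 (feature <= 0.5)
# and the last row as mat1 (feature > 0.5).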
# Choose the best feature and the corresponding split value
def choose_best_feature(dataset, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    min_error = ops[0]    # minimum required drop in loss
    min_samples = ops[1]  # minimum number of samples in each child
    # Stopping case 1: every sample has the same target value
    if len(set(dataset[:, -1])) == 1:
        return None, leafType(dataset)  # no split feature; return the mean as the leaf value
    m, n = shape(dataset)    # number of rows and columns
    Loss = errType(dataset)  # loss before splitting
    bestLoss = inf           # loss of the best split found so far
    bestIndex = 0            # index of the best split feature
    bestValue = 0            # best split value
    for featureIndex in range(n - 1):  # iterate over features (last column is the target)
        for value in set(dataset[:, featureIndex]):
            mat0, mat1 = splitDataset(dataset, featureIndex, value)
            # skip this candidate if either child would get fewer than min_samples samples
            if (shape(mat0)[0] < min_samples) or (shape(mat1)[0] < min_samples):
                continue
            curLoss = errType(mat0) + errType(mat1)
            if curLoss < bestLoss:
                bestLoss = curLoss
                bestIndex = featureIndex
                bestValue = value
    # Stopping case 2: the best split barely reduces the loss, so splitting is not worthwhile
    if Loss - bestLoss < min_error:
        return None, leafType(dataset)
    # re-split on the best feature/value (not the stale loop variables) to verify the sample constraint
    mat0, mat1 = splitDataset(dataset, bestIndex, bestValue)
    if (shape(mat0)[0] < min_samples) or (shape(mat1)[0] < min_samples):
        return None, leafType(dataset)
    return bestIndex, bestValue
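# Pre-pruning in effect: with ops = (min_error, min_samples) = (1, 4), a split is
# kept only if it lowers the total squared error by at least 1 and leaves at
# least 4 samples in each child; otherwise the node collapses into a leaf.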
# Build the regression tree recursively
def generate_tree(dataset, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    feature_index, value = choose_best_feature(dataset, leafType, errType, ops)  # best split for this node
    if feature_index is None:  # no worthwhile split: 'value' is the leaf value
        return value
    reTree = {}
    reTree['spInd'] = feature_index  # index of the split feature
    reTree['spVal'] = value          # split value
    lSet, rSet = splitDataset(dataset, feature_index, value)
    reTree['left'] = generate_tree(lSet, leafType, errType, ops)
    reTree['right'] = generate_tree(rSet, leafType, errType, ops)
    return reTree
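# The returned tree is a nested dict whose leaves are plain floats (region means).
# A hypothetical two-split tree looks like:
# {'spInd': 0, 'spVal': 0.5,
#  'left': 1.02,
#  'right': {'spInd': 0, 'spVal': 0.75, 'left': 2.30, 'right': 3.85}}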
# ------- Pruning -------
# Tell a subtree (dict) from a leaf (float)
def isTree(obj):
    return (type(obj).__name__ == 'dict')

# Collapse a subtree bottom-up into a single value: the average of its two children
def getMean(reTree):
    if isTree(reTree['left']):
        reTree['left'] = getMean(reTree['left'])
    if isTree(reTree['right']):
        reTree['right'] = getMean(reTree['right'])
    return (reTree['left'] + reTree['right']) / 2.0
# Post-pruning against a held-out test set
def postpruning(reTree, test_data):
    if shape(test_data)[0] == 0:  # no test samples reach this node: collapse the whole subtree
        print("test set for this node is empty")
        return getMean(reTree)
    if (isTree(reTree['left']) or isTree(reTree['right'])):
        lSet, rSet = splitDataset(test_data, reTree['spInd'], reTree['spVal'])
    if isTree(reTree['left']):
        reTree['left'] = postpruning(reTree['left'], lSet)
    if isTree(reTree['right']):
        reTree['right'] = postpruning(reTree['right'], rSet)
    if not isTree(reTree['left']) and not isTree(reTree['right']):
        lSet, rSet = splitDataset(test_data, reTree['spInd'], reTree['spVal'])
        # test error if the two leaves are kept separate
        errorNoMerge = sum(power(lSet[:, -1] - reTree['left'], 2)) + sum(power(rSet[:, -1] - reTree['right'], 2))
        treeMean = (reTree['left'] + reTree['right']) / 2.0
        # test error if the two leaves are merged into their mean
        errorMerge = sum(power(test_data[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("pruning: merging two leaves")
            return treeMean  # the merged node's value is the mean of the former left/right leaves
        else:
            return reTree  # keep the split
    else:
        return reTree
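# Merge rule: two sibling leaves collapse into their average whenever the squared
# error of that single value on the test rows reaching this node is lower than
# the combined squared error of the two separate leaves.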
# Predict a single sample by walking down the tree
def reTree_predict_one_test(reTree, one_test_example):
    first_feature_index = reTree['spInd']  # split feature of this node
    feature_spVal = reTree['spVal']        # split value of this node
    predict_val = 0.0
    if one_test_example[first_feature_index] <= feature_spVal:
        if isTree(reTree['left']):
            predict_val = reTree_predict_one_test(reTree['left'], one_test_example)
        else:
            predict_val = reTree['left']
    else:
        if isTree(reTree['right']):
            predict_val = reTree_predict_one_test(reTree['right'], one_test_example)
        else:
            predict_val = reTree['right']
    return predict_val
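# Example (using the hypothetical tree sketched after generate_tree):
# reTree_predict_one_test(tree, [0.6]) goes right (0.6 > 0.5), then left
# (0.6 <= 0.75), and returns the leaf 2.30.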
# Predict every sample in the test data
def reTree_predict(reTree, test_data):
    classLabel = []
    for one_test in test_data:
        classLabel.append(reTree_predict_one_test(reTree, one_test))
    return classLabel
if __name__ == '__main__':
    myData = load_data('./ex2.txt')
    reTree = generate_tree(myData)
    print("tree before pruning:")
    print(str(reTree))
    test_example = [0.8935465]
    predict_val = reTree_predict_one_test(reTree, test_example)
    print(predict_val)
    print("tree after pruning:")
    testData = load_data('./ex2test.txt')
    postPrunTree = postpruning(reTree, testData)  # note: may return a bare float if the whole tree merges
    print(postPrunTree)
    predict_val = reTree_predict_one_test(postPrunTree, test_example)
    print(predict_val)
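To quantify how well the pruned tree fits the held-out data, predictions can be correlated with the true targets. A minimal sketch, reusing the functions defined above (it assumes both data files sit in the working directory and that pruning leaves a dict tree rather than a single leaf; corrcoef comes from numpy):

myData = load_data('./ex2.txt')
testData = load_data('./ex2test.txt')
tree = postpruning(generate_tree(myData), testData)
yHat = reTree_predict(tree, testData)
R2 = corrcoef(array(yHat), testData[:, -1])[0, 1] ** 2  # squared correlation of predictions vs. targets
print(R2)  # closer to 1.0 means a better fit on the held-out data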
2. Run results
3. Datasets
3.1 ex2.txt
0.409175 1.883180
0.182603 0.063908
0.663687 3.042257
0.517395 2.305004
0.013643 -0.067698
0.469643 1.662809
0.725426 3.275749
0.394350 1.118077
0.507760 2.095059
0.237395 1.181912
0.057534 0.221663
0.369820 0.938453
0.976819 4.149409
0.616051 3.105444
0.413700 1.896278
0.105279 -0.121345
0.670273 3.161652
0.952758 4.135358
0.272316 0.859063
0.303697 1.170272
0.486698 1.687960
0.511810 1.979745
0.195865 0.068690
0.986769 4.052137
0.785623 3.156316
0.797583 2.950630
0.081306 0.068935
0.659753 2.854020
0.375270 0.999743
0.819136 4.048082
0.142432 0.230923
0.215112 0.816693
0.041270 0.130713
0.044136 -0.537706
0.131337 -0.339109
0.463444 2.124538
0.671905 2.708292
0.946559 4.017390
0.904176 4.004021
0.306674 1.022555
0.819006 3.657442
0.845472 4.073619
0.156258 0.011994
0.857185 3.640429
0.400158 1.808497
0.375395 1.431404
0.885807 3.935544
0.239960 1.162152
0.148640 -0.227330
0.143143 -0.068728
0.321582 0.825051
0.509393 2.008645
0.355891 0.664566
0.938633 4.180202
0.348057 0.864845
0.438898 1.851174
0.781419 2.761993
0.911333 4.075914
0.032469 0.110229
0.499985 2.181987
0.771663 3.152528
0.670361 3.046564
0.176202 0.128954
0.392170 1.062726
0.911188 3.651742
0.872288 4.401950
0.733107 3.022888
0.610239 2.874917
0.732739 2.946801
0.714825 2.893644
0.076386 0.072131
0.559009 1.748275
0.427258 1.912047
0.841875 3.710686
0.558918 1.719148
0.533241 2.174090
0.956665 3.656357
0.620393 3.522504
0.566120 2.234126
0.523258 1.859772
0.476884 2.097017
0.176408 0.001794
0.303094 1.231928
0.609731 2.953862
0.017774 -0.116803
0.622616 2.638864
0.886539 3.943428
0.148654 -0.328513
0.104350 -0.099866
0.116868 -0.030836
0.516514 2.359786
0.664896 3.212581
0.004327 0.188975
0.425559 1.904109
0.743671 3.007114
0.935185 3.845834
0.697300 3.079411
0.444551 1.939739
0.683753 2.880078
0.755993 3.063577
0.902690 4.116296
0.094491 -0.240963
0.873831 4.066299
0.991810 4.011834
0.185611 0.077710
0.694551 3.103069
0.657275 2.811897
0.118746 -0.104630
0.084302 0.025216
0.945341 4.330063
0.785827 3.087091
0.530933 2.269988
0.879594 4.010701
0.652770 3.119542
0.879338 3.723411
0.764739 2.792078
0.504884 2.192787
0.554203 2.081305
0.493209 1.714463
0.363783 0.885854
0.316465 1.028187
0.580283 1.951497
0.542898 1.709427
0.112661 0.144068
0.816742 3.880240
0.234175 0.921876
0.402804 1.979316
0.709423 3.085768
0.867298 3.476122
0.993392 3.993679
0.711580 3.077880
0.133643 -0.105365
0.052031 -0.164703
0.366806 1.096814
0.697521 3.092879
0.787262 2.987926
0.476710 2.061264
0.721417 2.746854
0.230376 0.716710
0.104397 0.103831
0.197834 0.023776
0.129291 -0.033299
0.528528 1.942286
0.009493 -0.006338
0.998533 3.808753
0.363522 0.652799
0.901386 4.053747
0.832693 4.569290
0.119002 -0.032773
0.487638 2.066236
0.153667 0.222785
0.238619 1.089268
0.208197 1.487788
0.750921 2.852033
0.183403 0.024486
0.995608 3.737750
0.151311 0.045017
0.126804 0.001238
0.983153 3.892763
0.772495 2.819376
0.784133 2.830665
0.056934 0.234633
0.425584 1.810782
0.998709 4.237235
0.707815 3.034768
0.413816 1.742106
0.217152 1.169250
0.360503 0.831165
0.977989 3.729376
0.507953 1.823205
0.920771 4.021970
0.210542 1.262939
0.928611 4.159518
0.580373 2.039114
0.841390 4.101837
0.681530 2.778672
0.292795 1.228284
0.456918 1.736620
0.134128 -0.195046
0.016241 -0.063215
0.691214 3.305268
0.582002 2.063627
0.303102 0.898840
0.622598 2.701692
0.525024 1.992909
0.996775 3.811393
0.881025 4.353857
0.723457 2.635641
0.676346 2.856311
0.254625 1.352682
0.488632 2.336459
0.519875 2.111651
0.160176 0.121726
0.609483 3.264605
0.531881 2.103446
0.321632 0.896855
0.845148 4.220850
0.012003 -0.217283
0.018883 -0.300577
0.071476 0.006014
3.2 ex2test.txt
0.421862 10.830241
0.105349 -2.241611
0.155196 21.872976
0.161152 2.015418
0.382632 -38.778979
0.017710 20.109113
0.129656 15.266887
0.613926 111.900063
0.409277 1.874731
0.807556 111.223754
0.593722 133.835486
0.953239 110.465070
0.257402 15.332899
0.645385 93.983054
0.563460 93.645277
0.408338 -30.719878
0.874394 91.873505
0.263805 -0.192752
0.411198 10.751118
0.449884 9.211901
0.646315 113.533660
0.673718 125.135638
0.805148 113.300462
0.759327 72.668572
0.519172 82.131698
0.741031 106.777146
0.030937 9.859127
0.268848 -34.137955
0.474901 -11.201301
0.588266 120.501998
0.893936 142.826476
0.870990 105.751746
0.430763 39.146258
0.057665 15.371897
0.100076 9.131761
0.980716 116.145896
0.235289 -13.691224
0.228098 16.089151
0.622248 99.345551
0.401467 -1.694383
0.960334 110.795415
0.031214 -5.330042
0.504228 96.003525
0.779660 75.921582
0.504496 101.341462
0.850974 96.293064
0.701119 102.333839
0.191551 5.072326
0.667116 92.310019
0.555584 80.367129
0.680006 132.965442
0.393899 38.605283
0.048940 -9.861871
0.963282 115.407485
0.655496 104.269918
0.576463 141.127267
0.675708 96.227996
0.853457 114.252288
0.003933 -12.182861
0.549512 97.927224
0.218967 -4.712462
0.659972 120.950439
0.008256 8.026816
0.099500 -14.318434
0.352215 -3.747546
0.874926 89.247356
0.635084 99.496059
0.039641 14.147109
0.665111 103.298719
0.156583 -2.540703
0.648843 119.333019
0.893237 95.209585
0.128807 5.558479
0.137438 5.567685
0.630538 98.462792
0.296084 -41.799438
0.632099 84.895098
0.987681 106.726447
0.744909 111.279705
0.862030 104.581156
0.080649 -7.679985
0.831277 59.053356
0.198716 26.878801
0.860932 90.632930
0.883250 92.759595
0.818003 110.272219
0.949216 115.200237
0.460078 -35.957981
0.561077 93.545761
0.863767 114.125786
0.476891 -29.774060
0.537826 81.587922
0.686224 110.911198
0.982327 119.114523
0.944453 92.033481
0.078227 30.216873
0.782937 92.588646
0.465886 2.222139
0.885024 90.247890
0.186077 7.144415
0.915828 84.010074
0.796649 115.572156
0.127821 28.933688
0.433429 6.782575
0.946796 108.574116
0.386915 -17.404601
0.561192 92.142700
0.182490 10.764616
0.878792 95.289476
0.381342 -6.177464
0.358474 -11.731754
0.270647 13.793201
0.488904 -17.641832
0.106773 5.684757
0.270112 4.335675
0.754985 75.860433
0.585174 111.640154
0.458821 12.029692
0.218017 -26.234872
0.583887 99.413850
0.923626 107.802298
0.833620 104.179678
0.870691 93.132591
0.249896 -8.618404
0.748230 109.160652
0.019365 34.048884
0.837588 101.239275
0.529251 115.514729
0.742898 67.038771
0.522034 64.160799
0.498982 3.983061
0.479439 24.355908
0.314834 -14.256200
0.753251 85.017092
0.479362 -17.480446
0.950593 99.072784
0.718623 58.080256
0.218720 -19.605593
0.664113 94.437159
0.942900 131.725134
0.314226 18.904871
0.284509 11.779346
0.004962 -14.624176
0.224087 -50.547649
0.974331 112.822725
0.894610 112.863995
0.167350 0.073380
0.753644 105.024456
0.632241 108.625812
0.314189 -6.090797
0.965527 87.418343
0.820919 94.610538
0.144107 -4.748387
0.072556 -5.682008
0.002447 29.685714
0.851007 79.632376
0.458024 -12.326026
0.627503 139.458881
0.422259 -29.827405
0.714659 63.480271
0.672320 93.608554
0.498592 37.112975
0.698906 96.282845
0.861441 99.699230
0.112425 -12.419909
0.164784 5.244704
0.481531 -18.070497
0.375482 1.779411
0.089325 -14.216755
0.036609 -6.264372
0.945004 54.723563
0.136608 14.970936
0.292285 -41.723711
0.029195 -0.660279
0.998307 100.124230
0.303928 -5.492264
0.957863 117.824392
0.815089 113.377704
0.466399 -10.249874
0.876693 115.617275
0.536121 102.997087
0.373984 -37.359936
0.565162 74.967476
0.085412 -21.449563
0.686411 64.859620
0.908752 107.983366
0.982829 98.005424
0.052766 -42.139502
0.777552 91.899340
0.374316 -3.522501
0.060231 10.008227
0.526225 87.317722
0.583872 67.104433
0.238276 10.615159
0.678747 60.624273
0.067649 15.947398
0.530182 105.030933
0.869389 104.969996
0.698410 75.460417
0.549430 82.558068