CART Regression Tree + Prediction: A Python Implementation

1. CART regression tree code
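
The code below grows the tree by recursive binary splitting. At each node, `choose_best_feature` scans every feature $j$ and every candidate split value $s$ for the pair that minimizes the summed squared error of the two child regions, each region being predicted by the mean of its targets:

$$
\min_{j,\,s}\left[\sum_{x_i \in R_1(j,\,s)} \left(y_i - \bar{y}_{R_1}\right)^2 + \sum_{x_i \in R_2(j,\,s)} \left(y_i - \bar{y}_{R_2}\right)^2\right]
$$

Splitting stops early when the loss reduction or the child sizes fall below the thresholds in `ops`, and the fitted tree is then post-pruned against a separate test set.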

from numpy import *

# Load the dataset from a tab-separated text file
def load_data(fileName):
    dataset = []
    fr = open(fileName).readlines()
    for line in fr:
        curline = line.strip().split('\t')
        fltline = list(map(float, curline))  # wrap map() in list(); map alone returns a lazy iterator
        dataset.append(fltline)
    dataset = array(dataset)  # convert to a NumPy array for easy computation
    return dataset
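
# Note: each row of ex2.txt / ex2test.txt below is "feature<TAB>target", so
# load_data returns an (m, 2) array: one feature column plus the target column.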

# Generate a leaf node: a leaf's value is the mean of the targets in its region
def regLeaf(dataset):
    return mean(dataset[:, -1])

# Squared-error loss: total variance = variance x sample count,
# i.e. the sum of squared differences between each target and the estimate (the leaf mean)
def regErr(dataset):
    return var(dataset[:, -1]) * shape(dataset)[0]
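
# Note: var(y) * m equals sum((y - mean(y)) ** 2), i.e. the total squared error
# incurred by predicting the region's mean for every sample in the region.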

# Split the dataset into left and right parts by a feature index and a split value
def splitDataset(dataset, index_of_feature, value):
    mat0 = dataset[nonzero(dataset[:, index_of_feature] <= value)[0], :]
    mat1 = dataset[nonzero(dataset[:, index_of_feature] > value)[0], :]
    return mat0, mat1
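
# e.g. splitDataset(data, 0, 0.5) returns (rows with feature 0 <= 0.5, the rest)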

# Choose the best feature and the corresponding split value
def choose_best_feature(dataset, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    min_error = ops[0]  # minimum required decrease in loss
    min_samples = ops[1]  # minimum number of samples in a child node
    # Stopping case 1: every sample in the dataset has the same target value
    if len(set(dataset[:, -1])) == 1:
        return None, leafType(dataset)  # no split feature; the leaf value is the mean

    m, n = shape(dataset)  # rows and columns of the dataset
    Loss = errType(dataset)  # loss before splitting
    bestLoss = inf  # loss of the best split found so far
    bestIndex = 0  # feature index of the best split
    bestValue = 0  # split value of the best split

    for featureIndex in range(n - 1):  # iterate over features (the last column is the target)
        for value in set(dataset[:, featureIndex]):
            mat0, mat1 = splitDataset(dataset, featureIndex, value)
            # skip this candidate if either child region has fewer than min_samples samples
            if (shape(mat0)[0] < min_samples) or (shape(mat1)[0] < min_samples):
                continue
            curLoss = errType(mat0) + errType(mat1)
            if curLoss < bestLoss:
                bestLoss = curLoss
                bestIndex = featureIndex
                bestValue = value
    # Stopping case 2: if (loss before split) - (loss after split) < min_error,
    # the improvement is too small, so stop splitting
    if Loss - bestLoss < min_error:
        return None, leafType(dataset)
    # Stopping case 3: check the child sizes of the *best* split once more
    mat0, mat1 = splitDataset(dataset, bestIndex, bestValue)
    if (shape(mat0)[0] < min_samples) or (shape(mat1)[0] < min_samples):
        return None, leafType(dataset)

    return bestIndex, bestValue
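
# ops = (minimum loss decrease, minimum samples per child) acts as pre-pruning:
# larger values stop the recursion earlier and produce a smaller tree.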

# Build the regression tree recursively
def generate_tree(dataset, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    feature_index, value = choose_best_feature(dataset, leafType, errType, ops)  # find the best split

    if feature_index is None:
        return value

    reTree = {}
    reTree['spInd'] = feature_index  # index of the best split feature
    reTree['spVal'] = value  # best split value
    lSet, rSet = splitDataset(dataset, feature_index, value)
    reTree['left'] = generate_tree(lSet, leafType, errType, ops)
    reTree['right'] = generate_tree(rSet, leafType, errType, ops)
    return reTree
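
# The returned tree is a nested dict; an illustrative (hypothetical) shape:
# {'spInd': 0, 'spVal': 0.39, 'left': 1.02, 'right': {'spInd': 0, ...}}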

# ------- Pruning -------
# Check whether a node is a subtree (dict) or a leaf
def isTree(obj):
    return (type(obj).__name__ == 'dict')

# Collapse a subtree into the average of its left and right children
def getMean(reTree):
    if isTree(reTree['left']):
        reTree['left'] = getMean(reTree['left'])
    if isTree(reTree['right']):
        reTree['right'] = getMean(reTree['right'])
    return (reTree['left'] + reTree['right']) / 2.0
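
# Post-pruning walks the fitted tree with held-out data: wherever both children
# are leaves, the split is kept only if merging them into their mean would
# increase the squared error on the test set.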

# Post-pruning against a test set
def postpruning(reTree, test_data):
    if shape(test_data)[0] == 0:
        # no test samples reach this subtree: collapse it to its mean
        return getMean(reTree)
    if (isTree(reTree['left']) or isTree(reTree['right'])):
        lSet, rSet = splitDataset(test_data, reTree['spInd'], reTree['spVal'])
    if isTree(reTree['left']):
        reTree['left'] = postpruning(reTree['left'], lSet)
    if isTree(reTree['right']):
        reTree['right'] = postpruning(reTree['right'], rSet)
    if not isTree(reTree['left']) and not isTree(reTree['right']):
        lSet, rSet = splitDataset(test_data, reTree['spInd'], reTree['spVal'])
        # test-set error if the split is kept
        errorNoMerge = sum(power(lSet[:, -1] - reTree['left'], 2)) + sum(power(rSet[:, -1] - reTree['right'], 2))
        treeMean = (reTree['left'] + reTree['right']) / 2.0
        # test-set error if the two leaves are merged into their mean
        errorMerge = sum(power(test_data[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("pruning")
            return treeMean  # the merged node's value is the mean of the former left and right leaves
        else:
            return reTree  # keep the split
    else:
        return reTree

# Predict a single sample with the regression tree
def reTree_predict_one_test(reTree, one_test_example):
    first_feature_index = reTree['spInd']  # split-feature index at this node
    feature_spVal = reTree['spVal']  # split value at this node
    predict_val = 0.0

    if one_test_example[first_feature_index] <= feature_spVal:
        if isTree(reTree['left']):
            predict_val = reTree_predict_one_test(reTree['left'], one_test_example)
        else:
            predict_val = reTree['left']
    else:
        if isTree(reTree['right']):
            predict_val = reTree_predict_one_test(reTree['right'], one_test_example)
        else:
            predict_val = reTree['right']
    return predict_val

# Predict results for every test sample
def reTree_predict(reTree, test_data):
    classLabel = []
    for one_test in test_data:
        classLabel.append(reTree_predict_one_test(reTree, one_test))
    return classLabel

if __name__ == '__main__':
    myData = load_data('./ex2.txt')
    reTree = generate_tree(myData)
    print("Tree before pruning:")
    print(reTree)
    test_example = [0.8935465]
    predict_val = reTree_predict_one_test(reTree, test_example)
    print(predict_val)
    print("Tree after pruning:")
    testData = load_data('./ex2test.txt')
    postPrunTree = postpruning(reTree, testData)
    print(postPrunTree)
    predict_val = reTree_predict_one_test(postPrunTree, test_example)
    print(predict_val)
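
    # A minimal extra check (not in the original post): score the pruned tree on
    # the test set with the training loss (total squared error). The target
    # column of each row is never indexed, since 'spInd' refers to feature
    # columns only; the isTree guard covers the rare case where pruning
    # collapsed the whole tree to a single value.
    if isTree(postPrunTree):
        preds = array(reTree_predict(postPrunTree, testData))
        print("test-set squared error:", sum(power(testData[:, -1] - preds, 2)))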


2. Running results

(Screenshot of the console output: the tree before pruning, the tree after pruning, and the predictions for test_example.)

3. Datasets

3.1 ex2.txt

0.409175	1.883180
0.182603	0.063908
0.663687	3.042257
0.517395	2.305004
0.013643	-0.067698
0.469643	1.662809
0.725426	3.275749
0.394350	1.118077
0.507760	2.095059
0.237395	1.181912
0.057534	0.221663
0.369820	0.938453
0.976819	4.149409
0.616051	3.105444
0.413700	1.896278
0.105279	-0.121345
0.670273	3.161652
0.952758	4.135358
0.272316	0.859063
0.303697	1.170272
0.486698	1.687960
0.511810	1.979745
0.195865	0.068690
0.986769	4.052137
0.785623	3.156316
0.797583	2.950630
0.081306	0.068935
0.659753	2.854020
0.375270	0.999743
0.819136	4.048082
0.142432	0.230923
0.215112	0.816693
0.041270	0.130713
0.044136	-0.537706
0.131337	-0.339109
0.463444	2.124538
0.671905	2.708292
0.946559	4.017390
0.904176	4.004021
0.306674	1.022555
0.819006	3.657442
0.845472	4.073619
0.156258	0.011994
0.857185	3.640429
0.400158	1.808497
0.375395	1.431404
0.885807	3.935544
0.239960	1.162152
0.148640	-0.227330
0.143143	-0.068728
0.321582	0.825051
0.509393	2.008645
0.355891	0.664566
0.938633	4.180202
0.348057	0.864845
0.438898	1.851174
0.781419	2.761993
0.911333	4.075914
0.032469	0.110229
0.499985	2.181987
0.771663	3.152528
0.670361	3.046564
0.176202	0.128954
0.392170	1.062726
0.911188	3.651742
0.872288	4.401950
0.733107	3.022888
0.610239	2.874917
0.732739	2.946801
0.714825	2.893644
0.076386	0.072131
0.559009	1.748275
0.427258	1.912047
0.841875	3.710686
0.558918	1.719148
0.533241	2.174090
0.956665	3.656357
0.620393	3.522504
0.566120	2.234126
0.523258	1.859772
0.476884	2.097017
0.176408	0.001794
0.303094	1.231928
0.609731	2.953862
0.017774	-0.116803
0.622616	2.638864
0.886539	3.943428
0.148654	-0.328513
0.104350	-0.099866
0.116868	-0.030836
0.516514	2.359786
0.664896	3.212581
0.004327	0.188975
0.425559	1.904109
0.743671	3.007114
0.935185	3.845834
0.697300	3.079411
0.444551	1.939739
0.683753	2.880078
0.755993	3.063577
0.902690	4.116296
0.094491	-0.240963
0.873831	4.066299
0.991810	4.011834
0.185611	0.077710
0.694551	3.103069
0.657275	2.811897
0.118746	-0.104630
0.084302	0.025216
0.945341	4.330063
0.785827	3.087091
0.530933	2.269988
0.879594	4.010701
0.652770	3.119542
0.879338	3.723411
0.764739	2.792078
0.504884	2.192787
0.554203	2.081305
0.493209	1.714463
0.363783	0.885854
0.316465	1.028187
0.580283	1.951497
0.542898	1.709427
0.112661	0.144068
0.816742	3.880240
0.234175	0.921876
0.402804	1.979316
0.709423	3.085768
0.867298	3.476122
0.993392	3.993679
0.711580	3.077880
0.133643	-0.105365
0.052031	-0.164703
0.366806	1.096814
0.697521	3.092879
0.787262	2.987926
0.476710	2.061264
0.721417	2.746854
0.230376	0.716710
0.104397	0.103831
0.197834	0.023776
0.129291	-0.033299
0.528528	1.942286
0.009493	-0.006338
0.998533	3.808753
0.363522	0.652799
0.901386	4.053747
0.832693	4.569290
0.119002	-0.032773
0.487638	2.066236
0.153667	0.222785
0.238619	1.089268
0.208197	1.487788
0.750921	2.852033
0.183403	0.024486
0.995608	3.737750
0.151311	0.045017
0.126804	0.001238
0.983153	3.892763
0.772495	2.819376
0.784133	2.830665
0.056934	0.234633
0.425584	1.810782
0.998709	4.237235
0.707815	3.034768
0.413816	1.742106
0.217152	1.169250
0.360503	0.831165
0.977989	3.729376
0.507953	1.823205
0.920771	4.021970
0.210542	1.262939
0.928611	4.159518
0.580373	2.039114
0.841390	4.101837
0.681530	2.778672
0.292795	1.228284
0.456918	1.736620
0.134128	-0.195046
0.016241	-0.063215
0.691214	3.305268
0.582002	2.063627
0.303102	0.898840
0.622598	2.701692
0.525024	1.992909
0.996775	3.811393
0.881025	4.353857
0.723457	2.635641
0.676346	2.856311
0.254625	1.352682
0.488632	2.336459
0.519875	2.111651
0.160176	0.121726
0.609483	3.264605
0.531881	2.103446
0.321632	0.896855
0.845148	4.220850
0.012003	-0.217283
0.018883	-0.300577
0.071476	0.006014

3.2 ex2test.txt

0.421862	10.830241
0.105349	-2.241611
0.155196	21.872976
0.161152	2.015418
0.382632	-38.778979
0.017710	20.109113
0.129656	15.266887
0.613926	111.900063
0.409277	1.874731
0.807556	111.223754
0.593722	133.835486
0.953239	110.465070
0.257402	15.332899
0.645385	93.983054
0.563460	93.645277
0.408338	-30.719878
0.874394	91.873505
0.263805	-0.192752
0.411198	10.751118
0.449884	9.211901
0.646315	113.533660
0.673718	125.135638
0.805148	113.300462
0.759327	72.668572
0.519172	82.131698
0.741031	106.777146
0.030937	9.859127
0.268848	-34.137955
0.474901	-11.201301
0.588266	120.501998
0.893936	142.826476
0.870990	105.751746
0.430763	39.146258
0.057665	15.371897
0.100076	9.131761
0.980716	116.145896
0.235289	-13.691224
0.228098	16.089151
0.622248	99.345551
0.401467	-1.694383
0.960334	110.795415
0.031214	-5.330042
0.504228	96.003525
0.779660	75.921582
0.504496	101.341462
0.850974	96.293064
0.701119	102.333839
0.191551	5.072326
0.667116	92.310019
0.555584	80.367129
0.680006	132.965442
0.393899	38.605283
0.048940	-9.861871
0.963282	115.407485
0.655496	104.269918
0.576463	141.127267
0.675708	96.227996
0.853457	114.252288
0.003933	-12.182861
0.549512	97.927224
0.218967	-4.712462
0.659972	120.950439
0.008256	8.026816
0.099500	-14.318434
0.352215	-3.747546
0.874926	89.247356
0.635084	99.496059
0.039641	14.147109
0.665111	103.298719
0.156583	-2.540703
0.648843	119.333019
0.893237	95.209585
0.128807	5.558479
0.137438	5.567685
0.630538	98.462792
0.296084	-41.799438
0.632099	84.895098
0.987681	106.726447
0.744909	111.279705
0.862030	104.581156
0.080649	-7.679985
0.831277	59.053356
0.198716	26.878801
0.860932	90.632930
0.883250	92.759595
0.818003	110.272219
0.949216	115.200237
0.460078	-35.957981
0.561077	93.545761
0.863767	114.125786
0.476891	-29.774060
0.537826	81.587922
0.686224	110.911198
0.982327	119.114523
0.944453	92.033481
0.078227	30.216873
0.782937	92.588646
0.465886	2.222139
0.885024	90.247890
0.186077	7.144415
0.915828	84.010074
0.796649	115.572156
0.127821	28.933688
0.433429	6.782575
0.946796	108.574116
0.386915	-17.404601
0.561192	92.142700
0.182490	10.764616
0.878792	95.289476
0.381342	-6.177464
0.358474	-11.731754
0.270647	13.793201
0.488904	-17.641832
0.106773	5.684757
0.270112	4.335675
0.754985	75.860433
0.585174	111.640154
0.458821	12.029692
0.218017	-26.234872
0.583887	99.413850
0.923626	107.802298
0.833620	104.179678
0.870691	93.132591
0.249896	-8.618404
0.748230	109.160652
0.019365	34.048884
0.837588	101.239275
0.529251	115.514729
0.742898	67.038771
0.522034	64.160799
0.498982	3.983061
0.479439	24.355908
0.314834	-14.256200
0.753251	85.017092
0.479362	-17.480446
0.950593	99.072784
0.718623	58.080256
0.218720	-19.605593
0.664113	94.437159
0.942900	131.725134
0.314226	18.904871
0.284509	11.779346
0.004962	-14.624176
0.224087	-50.547649
0.974331	112.822725
0.894610	112.863995
0.167350	0.073380
0.753644	105.024456
0.632241	108.625812
0.314189	-6.090797
0.965527	87.418343
0.820919	94.610538
0.144107	-4.748387
0.072556	-5.682008
0.002447	29.685714
0.851007	79.632376
0.458024	-12.326026
0.627503	139.458881
0.422259	-29.827405
0.714659	63.480271
0.672320	93.608554
0.498592	37.112975
0.698906	96.282845
0.861441	99.699230
0.112425	-12.419909
0.164784	5.244704
0.481531	-18.070497
0.375482	1.779411
0.089325	-14.216755
0.036609	-6.264372
0.945004	54.723563
0.136608	14.970936
0.292285	-41.723711
0.029195	-0.660279
0.998307	100.124230
0.303928	-5.492264
0.957863	117.824392
0.815089	113.377704
0.466399	-10.249874
0.876693	115.617275
0.536121	102.997087
0.373984	-37.359936
0.565162	74.967476
0.085412	-21.449563
0.686411	64.859620
0.908752	107.983366
0.982829	98.005424
0.052766	-42.139502
0.777552	91.899340
0.374316	-3.522501
0.060231	10.008227
0.526225	87.317722
0.583872	67.104433
0.238276	10.615159
0.678747	60.624273
0.067649	15.947398
0.530182	105.030933
0.869389	104.969996
0.698410	75.460417
0.549430	82.558068

