CART---回归树

本算法根据《机器学习实战》改编而来,对回归树的详细说明请参照原书,数据的下载地址 https://pan.baidu.com/s/1gfKrBqj,

以下为Matlab程序主程序:

clc;
clear;
%加载测试数据文件,前两列为坐标值,后两列为类标号
fileID = fopen('D:\matlabFile\CART\CART.txt');
DS=textscan(fileID,'%f %f');
fclose(fileID);
%将数据转为矩阵形式
Dataset=cat(2,DS{1},DS{2});
%用户设定参数
ops=[1,4];
%创建树
R=CreateTree(Dataset,ops);
DataSet=R{1,2};
SDS1=DataSet{1,1};
SDS2=DataSet{1,2};
%显示分组的数据集
scatter(SDS1(:,1),SDS1(:,2),'filled');
hold on
scatter(SDS2(:,1),SDS2(:,2),'filled');

建树函数CreateTree:

function Result=CreateTree(Dataset,ops)
%选择分割特征参数
R=ChooseSplitFeature(Dataset,ops);
if R(1)==0
    Result=R(2);
    return;
end
%构造回归树,
RegTree=cell(4,1);
%第一行为特征列的索引号
RegTree{1,1}=R(1);
%第二行为该特征的门限
RegTree{2,1}=R(2);
%根据特征值拆分数据集
DS=SplitDataset(Dataset,R(1),R(2));
%第三行存放左子树
RegTree{3,1}=CreateTree(DS{1,1},ops);
%第四行存放右子树
RegTree{4,1}=CreateTree(DS{1,2},ops);
%返回树以及拆分后的数据集
Result={RegTree,DS};
end

选择分割的特征的参数索引ChooseSplitFeature:

function Result=ChooseSplitFeature(Dataset,ops)
TolS=ops(1);
TolN=ops(2);
%如果最后一列只有一个相同的值,返回
if length(unique(Dataset(:,end)))==1
    Result=[0,mean(Dataset(:,end))];
    return;
end
[Row,Column]=size(Dataset);
%计算最后一列的方差
S=(var(Dataset(:,end)))*Row;
BestS=Inf;Index=1;Value=0;
for FeatIndex=1:Column-1
    Col=unique(Dataset(:,FeatIndex));
    for j=1:length(Col)
        R=SplitDataset(Dataset,FeatIndex,Col(j));
        DS1=R{1,1};
        DS2=R{1,2};
        if (size(DS1,1)<TolN)||(size(DS2,1)<TolN)
            continue;
        end
        NewS=(var(DS1(:,end)))*Row+(var(DS2(:,end)))*Row;
        if NewS<BestS
            Index=FeatIndex;
            Value=Col(j);
            BestS=NewS;
        end
    end
end
if (S-BestS)<TolS
    Result=[0,mean(Dataset(:,end))];
    return;    
end
R=SplitDataset(Dataset,Index,Value);
DS1=R{1,1};
DS2=R{1,2};
if (size(DS1,1)<TolN)||(size(DS2,1)<TolN)
    Result=[0,mean(Dataset(:,end))];
    return;  
end
Result=[Index,Value];
end

数据集分割函数SplitDataset:

function DS=SplitDataset(Dataset,F,Threshold)
%取出第F个特征列
Feature=Dataset(:,F);
%获得该特征大于门限的索引号
Index1=find(Feature>Threshold);
%取出于索引号相对应的数据
DS1=Dataset(Index1,:);
%获得该特征小于门限的索引号
Index2=find(Feature<=Threshold);
%取出于索引号相对应的数据
DS2=Dataset(Index2,:);
DS={DS1,DS2};
end

下面是实验结果图,仅供参考:

            
版权声明:本文为博主原创文章,未经博主允许不得转载。



  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值