matlab中cart树,matlab实现cart(回归分类树)

作为机器学习的小白和matlab的小白自己参照 python的 《机器学习实战》 写了一下分类回归树,这里记录一下。

关于决策树的基础概念就不过多介绍了,至于是分类还是回归。。我说不清楚。。我用的数据集是这个http://archive.ics.uci.edu/ml/datasets/Abalone 就是通过一些属性来预测鲍鱼有多少头,下面看一下

Length / continuous / mm / Longest shell measurement

Diameter/ continuous / mm / perpendicular to length

Height / continuous / mm / with meat in shell

Whole weight / continuous / grams / whole abalone

Shucked weight / continuous/ grams / weight of meat

Viscera weight / continuous / grams / gut weight (after bleeding)

Shell weight / continuous / grams / after being dried

Rings / integer / -- / +1.5 gives the age in years

这些属性除了最后的Rings是整数,可以看做是离散的,其他都是浮点数,是连续的。所以还是用cart中二分的思想,就是小于等于分一边,大于分一边。但是没有用gini指数,因为熵还是好一点。

参照《机器学习实战》代码有5个部分:getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集),chooseBestFeatureToSplit(寻找最佳分割点和阈值),createTree(建树),predict(预测)。

我按流程梳理一下,首先函数脚本来将数据集划分成,训练集和测试集,然后用训练集建树,用测试集测试,(更改后变成bootstrap sampleing)

dataset = importdata('abalone.data.txt') ;

origin_data=dataset.data ;

labels= {'Length';'Diam';'Height'; 'Whole';'Shucked';'Viscera';'Shell';'Rings'} ;

test_runtimes= 50;

ae= 0;

rr= 0;for i=1:test_runtimes

data= sampleWithReplace(origin_data) ;%bootstrap sampling

len= floor(length(data)/4*3) ;

train_data= data(1:len,:) ;

test_data=data(len:end,:) ;

test_y_truth=test_data(:,end) ;% tree = createTree(train_data,labels,0) ;% predict_y =predict(tree,test_data,labels) ;% com_matrix =[predict_y,test_y_truth] ;% count = sum(predict_y==test_y_truth) ;%disp(com_matrix) ;%disp(mae) ;%disp(rr) ;%plot single runtime% x = 1:1:size(test_y_truth,1) ;% plot(x,predict_y,'-b',x,test_y_truth,'-r') ;

ae= ae+sum(abs(predict_y-test_y_truth))/size(test_y_truth,1) ;

rr= rr+count/size(test_y_truth,1) ;%trian with office tools fitctree

std_tree= fitctree(train_data(:,1:7),train_data(:,end)) ;%view(std_tree) ;

std_y= predict(std_tree,test_data(:,1:7)) ;%disp([std_y,y]) ;

ae= ae+sum(abs(std_y-test_y_truth))/size(test_y_truth,1) ;

rr= rr+sum(std_y==test_y_truth)/size(test_y_truth,1) ;

end

mae= mae /test_runtimes ;

mrr= rr /test_runtimes ;

disp('mae') ;

disp(mae) ;

disp('mrr') ;

disp(mrr) ;

createTree函数:由于matlab没有指针,所以只能写成嵌套结构,就像tree{tree{tree}}这样。我们是递归实现的,但怎么样才会停止建树?条件是当前节点所有标签的类别一样,比如rings都为10,那说明这一个子集已经纯了,或者是这颗树的高度已经超出了我们设的阈值,就停止,第二种情况很可能当前节点下的数据集不纯,我们就找一个出现频率最高的类别代表该节点

function [ tree ] =createTree( dataset,labels,heightcount )

len= size(dataset,1) ;

templabel= dataset(1,end) ;

tree=templabel ;

max_depth= 5 ;%最大树高

flag= 1 ; %判断是否数据集中所有标签都一致了(纯的),是则返回for i=1:lenif templabel~=dataset(i,end) ;

flag= 0;

end

endif flag==1

return;

endif heightcount>max_depth

labelVec=dataset(:,end) ;

disp(labelVec) ;

element= 1:max(labelVec) ;

counts=histc(labelVec,element) ;

[~,max_idx] =max(counts) ;

tree=element(max_idx) ;return;

end

[bestFeat,bestT]=chooseBestFeatureToSplit(dataset) ;

bestFeatLabel=labels{bestFeat} ;

tree= struct ;%struct储存树结构

tree.bestFeatLabel=bestFeatLabel ;

tree.bestT=bestT ;

tree.greaterthan= createTree(splitDataset(dataset,bestFeat,bestT,1),labels,heightcount+1) ;%大于阈值部分的子树

tree.lessthan= createTree(splitDataset(dataset,bestFeat,bestT,2),labels,heightcount+1) ;%小于阈值部分的子树

end

chooseBestFeatureToSplit函数:在createTree时,每次递归都要找那个当前最佳的特征和阈值,也就是调用chooseBestFeatureToSplit函数,所以两层循环,第一层遍历每个属性,第二层本应该遍历每个属性下的值,但是那样计算量太大了,所以我就将值排序之后分成10端取中位数遍历,在里面找阈值,如果当前节点的数据子集已经不足10个里,那就把所有属性都遍历一哈

function [ bestFeat,bestT ] =chooseBestFeatureToSplit( dataset )

[~,numFeats] =size(dataset) ;

numFeats= numFeats-1 ;%除去标签那一列

baseEnt=getEnt(dataset) ;

baseInfoGain= 0;

bestFeat= -1;for i=1:numFeats

featVec=dataset(:,i) ;%由于值是连续的,所以对于特征向量组排序分成n段取中位数

sortedFeatVec= sort(featVec,'ascend') ;

lengthofT= floor(sqrt(length(sortedFeatVec))) ; %取向量长度开根号来确定阈值的个数if lengthofT<10lengthofT=length(sortedFeatVec) ;

selectedFeat=sortedFeatVec ;elsestep= floor(length(sortedFeatVec)/lengthofT) ;

selectedFeat= zeros(lengthofT,1) ;for j=1:lengthofT

head= (j-1)*step+1;

tail= j*step ;

subSortedFeatVec=sortedFeatVec(head:tail) ;

selectedFeat(j)=median(subSortedFeatVec) ;

end

endfor k=1:lengthofT

newEnt= 0;for l=1:2subDataset=splitDataset(dataset,i,selectedFeat(k),l) ;

prob= size(subDataset,1)/size(dataset,1) ;

newEnt= newEnt + prob*getEnt(subDataset) ;

end

infoGain= baseEnt -newEnt ;% disp('infoGain') ;%disp(infoGain) ;if(infoGain>baseInfoGain)

baseInfoGain=infoGain ;

bestFeat=i ;

bestT=selectedFeat(k) ;

end

end

end

end

计算信息增益(infoGain)的时候需要用到getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集)函数

splitDataset:

3350b879e51a2d3a4ab06ea2e0854724.gif

5c953341f0166a24adb5872919d817e9.gif

function [ retDataset ] =splitDataset(dataset,axis,value,arg )%axis 代表键值的位置 value表示阈值 返回划分后的dataset,arg表示取大于的部分(1)还是小于等于的部分if arg==1retDataset= dataset(dataset(:,axis)>value,:) ;elseretDataset= dataset(dataset(:,axis)<=value,:) ;

end

end

View Code

getEnt:

3350b879e51a2d3a4ab06ea2e0854724.gif

5c953341f0166a24adb5872919d817e9.gif

function [ ent ] =getEnt( data )%index present the label

[datalen,~] =size(data) ;

maxLabel=max(data(:,end)) ;

labelCountsMap= zeros(maxLabel,1) ;%rings are all numbersfor i=1:datalen

label=data(i,end) ;if labelCountsMap(label)~=0labelCountsMap(label)= labelCountsMap(label) + 1;elselabelCountsMap(label)= 1;

end

end

ent= 0;% disp('labelMap') ;%disp(labelCountsMap) ;for i=1:maxLabelif labelCountsMap(i)~=0prob= labelCountsMap(i)/datalen ;

ent= ent - prob*log2(prob) ;

end

end

end

View Code

最后预测函数:

3350b879e51a2d3a4ab06ea2e0854724.gif

5c953341f0166a24adb5872919d817e9.gif

function [ classVec ] =predict( tree , dataset , labels)%tree应由createTree函数生成

len= size(dataset,1) ;

classVec= zeros(len,1) ;for i=1:len

dataVec= dataset(i,1:end-1) ;

tempnode=tree ;while(isstruct(tempnode))

[~,tempFeatIdx] =ismember(tempnode.bestFeatLabel,labels) ;if(dataVec(tempFeatIdx)>tempnode.bestT)

tempnode=tempnode.greaterthan ;elsetempnode=tempnode.lessthan ;

end

end

classVec(i)=tempnode ;

end

end

View Code

更新了一下代码,加入了boostrap采样,就是有放回的采样,我是这样采用的,有多少个样本就进行多少次有放回采样,然后这个过程进行50次求均值。用了之后,官方的库正确率道理44%,而我的还在30%。。差距一下突显,还需继续学习。。

补充一下那个sampleWithReplace函数

function [ sample_data ] =sampleWithReplace( dataset )

len= size(dataset,1) ;

randidx= randsample(len,len,true) ;

sample_data=dataset(randidx,:) ;

end

内容来源于网络如有侵权请私信删除

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是使用MATLAB编写的CART决策代码: ```matlab % 数据预处理 load ionosphere; length = size(X, 1); rng(1); % 可复现 indices = crossvalind('Kfold', length, 5); % 用k折分类法将样本随机分为5部分 % 四份用来训练,一份进行测试 i = 1; test = (indices == i); train = ~test; X_train = X(train, :); Y_train = Y(train, :); X_test = X(test, :); Y_test = Y(test, :); % 构建CART算法分类 tree = fitctree(X_train, Y_train); view(tree, 'Mode', 'graph'); % 生成图 % 求取规则数量 rules_num = (tree.IsBranchNode == 0); rules_num = sum(rules_num); % 使用测试样本进行验证 Cart_result = predict(tree, X_test); Cart_result = cell2mat(Cart_result); Y_test = cell2mat(Y_test); Cart_result = (Cart_result == Y_test); % 统计准确率 Cart_length = size(Cart_result, 1); Cart_rate = (sum(Cart_result)) / Cart_length; disp(\['规则数:' num2str(rules_num)\]); disp(\['测试样本识别准确率:' num2str(Cart_rate)\]); ``` 这段代码首先加载了ionosphere数据集,并将数据集分为训练集和测试集。然后使用CART算法构建了一个分类,并生成了图。接着统计了分类的规则数量。最后使用测试样本对分类进行验证,并计算了测试样本的识别准确率。 #### 引用[.reference_title] - *1* [使用matlab实现决策cart算法(基于fitctree函数)](https://blog.csdn.net/u010356524/article/details/79848624)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值