matlab中cart pos 怎么设置,matlab实现cart(回归分类树)

作为机器学习的小白和matlab的小白自己参照 python的 《机器学习实战》 写了一下分类回归树,这里记录一下。

关于决策树的基础概念就不过多介绍了,至于是分类还是回归。。我说不清楚。。我用的数据集是这个http://archive.ics.uci.edu/ml/datasets/Abalone 就是通过一些属性来预测鲍鱼有多少头,下面看一下

Length / continuous / mm / Longest shell measurement

Diameter/ continuous / mm / perpendicular to length

Height / continuous / mm / with meat in shell

Whole weight / continuous / grams / whole abalone

Shucked weight / continuous/ grams / weight of meat

Viscera weight / continuous / grams / gut weight (after bleeding)

Shell weight / continuous / grams / after being dried

Rings / integer / -- / +1.5 gives the age in years

这些属性除了最后的Rings是整数,可以看做是离散的,其他都是浮点数,是连续的。所以还是用cart中二分的思想,就是小于等于分一边,大于分一边。但是没有用gini指数,因为熵还是好一点。

参照《机器学习实战》代码有5个部分:getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集),chooseBestFeatureToSplit(寻找最佳分割点和阈值),createTree(建树),predict(预测)。

我按流程梳理一下,首先函数脚本来将数据集划分成,训练集和测试集,然后用训练集建树,用测试集测试,(更改后变成bootstrap sampleing)

dataset = importdata('abalone.data.txt') ;

origin_data = dataset.data ;

labels = {'Length';'Diam';'Height'; 'Whole';'Shucked';'Viscera';'Shell';'Rings'} ;

test_runtimes = ;

ae = ;

rr = ;

for i=:test_runtimes

data = sampleWithReplace(origin_data) ;%bootstrap sampling

len = floor(length(data)/*) ;

train_data = data(:len,:) ;

test_data = data(len:end,:) ;

test_y_truth = test_data(:,end) ;

% tree = createTree(train_data,labels,) ;

% predict_y = predict(tree,test_data,labels) ;

% com_matrix = [predict_y,test_y_truth] ;

% count = sum(predict_y==test_y_truth) ;

% disp(com_matrix) ;

% disp(mae) ;

% disp(rr) ;

%plot single runtime

% x = ::size(test_y_truth,) ;

% plot(x,predict_y,'-b',x,test_y_truth,'-r') ;

ae = ae+sum(abs(predict_y-test_y_truth))/size(test_y_truth,) ;

rr = rr+count/size(test_y_truth,) ;

%trian with office tools fitctree

std_tree = fitctree(train_data(:,:),train_data(:,end)) ;

% view(std_tree) ;

std_y = predict(std_tree,test_data(:,:)) ;

% disp([std_y,y]) ;

ae = ae+sum(abs(std_y-test_y_truth))/size(test_y_truth,) ;

rr = rr+sum(std_y==test_y_truth)/size(test_y_truth,) ;

end

mae = mae / test_runtimes ;

mrr = rr / test_runtimes ;

disp('mae') ;

disp(mae) ;

disp('mrr') ;

disp(mrr) ;

createTree函数:由于matlab没有指针,所以只能写成嵌套结构,就像tree{tree{tree}}这样。我们是递归实现的,但怎么样才会停止建树?条件是当前节点所有标签的类别一样,比如rings都为10,那说明这一个子集已经纯了,或者是这颗树的高度已经超出了我们设的阈值,就停止,第二种情况很可能当前节点下的数据集不纯,我们就找一个出现频率最高的类别代表该节点

function [ tree ] = createTree( dataset,labels,heightcount )

len = size(dataset,) ;

templabel = dataset(,end) ;

tree = templabel ;

max_depth = ;%最大树高

flag = ; %判断是否数据集中所有标签都一致了(纯的),是则返回

for i=:len

if templabel~=dataset(i,end) ;

flag = ;

end

end

if flag==

return ;

end

if heightcount>max_depth

labelVec = dataset(:,end) ;

disp(labelVec) ;

element = :max(labelVec) ;

counts = histc(labelVec,element) ;

[~,max_idx] = max(counts) ;

tree = element(max_idx) ;

return ;

end

[bestFeat,bestT] = chooseBestFeatureToSplit(dataset) ;

bestFeatLabel = labels{bestFeat} ;

tree = struct ;%struct储存树结构

tree.bestFeatLabel = bestFeatLabel ;

tree.bestT = bestT ;

tree.greaterthan = createTree(splitDataset(dataset,bestFeat,bestT,),labels,heightcount+) ;%大于阈值部分的子树

tree.lessthan = createTree(splitDataset(dataset,bestFeat,bestT,),labels,heightcount+) ;%小于阈值部分的子树

end

chooseBestFeatureToSplit函数:在createTree时,每次递归都要找那个当前最佳的特征和阈值,也就是调用chooseBestFeatureToSplit函数,所以两层循环,第一层遍历每个属性,第二层本应该遍历每个属性下的值,但是那样计算量太大了,所以我就将值排序之后分成10端取中位数遍历,在里面找阈值,如果当前节点的数据子集已经不足10个里,那就把所有属性都遍历一哈

function [ bestFeat,bestT ] = chooseBestFeatureToSplit( dataset )

[~,numFeats] = size(dataset) ;

numFeats = numFeats- ;%除去标签那一列

baseEnt = getEnt(dataset) ;

baseInfoGain = ;

bestFeat = - ;

for i=:numFeats

featVec = dataset(:,i) ;

%由于值是连续的,所以对于特征向量组排序分成n段取中位数

sortedFeatVec = sort(featVec,'ascend') ;

lengthofT = floor(sqrt(length(sortedFeatVec))) ; %取向量长度开根号来确定阈值的个数

if lengthofT<

lengthofT = length(sortedFeatVec) ;

selectedFeat = sortedFeatVec ;

else

step = floor(length(sortedFeatVec)/lengthofT) ;

selectedFeat = zeros(lengthofT,) ;

for j=:lengthofT

head = (j-)*step+ ;

tail = j*step ;

subSortedFeatVec = sortedFeatVec(head:tail) ;

selectedFeat(j) = median(subSortedFeatVec) ;

end

end

for k=:lengthofT

newEnt = ;

for l=:

subDataset = splitDataset(dataset,i,selectedFeat(k),l) ;

prob = size(subDataset,)/size(dataset,) ;

newEnt = newEnt + prob*getEnt(subDataset) ;

end

infoGain = baseEnt - newEnt ;

% disp('infoGain') ;

% disp(infoGain) ;

if(infoGain>baseInfoGain)

baseInfoGain = infoGain ;

bestFeat= i ;

bestT = selectedFeat(k) ;

end

end

end

end

计算信息增益(infoGain)的时候需要用到getEnt(获取信息熵),splitDataset(通过属性和阈值分割数据集)函数

splitDataset:

function [ retDataset ] = splitDataset(dataset,axis,value,arg )

%axis 代表键值的位置 value表示阈值 返回划分后的dataset,arg表示取大于的部分()还是小于等于的部分

if arg==

retDataset = dataset(dataset(:,axis)>value,:) ;

else

retDataset = dataset(dataset(:,axis)<=value,:) ;

end

end

getEnt:

function [ ent ] = getEnt( data )

%index present the label

[datalen,~] = size(data) ;

maxLabel = max(data(:,end)) ;

labelCountsMap = zeros(maxLabel,) ;%rings are all numbers

for i=:datalen

label = data(i,end) ;

if labelCountsMap(label)~=

labelCountsMap(label) = labelCountsMap(label) + ;

else

labelCountsMap(label) = ;

end

end

ent = ;

% disp('labelMap') ;

% disp(labelCountsMap) ;

for i=:maxLabel

if labelCountsMap(i)~=

prob = labelCountsMap(i)/datalen ;

ent = ent - prob*log2(prob) ;

end

end

end

最后预测函数:

function [ classVec ] = predict( tree , dataset , labels)

%tree应由createTree函数生成

len = size(dataset,) ;

classVec = zeros(len,) ;

for i=:len

dataVec = dataset(i,:end-) ;

tempnode = tree ;

while(isstruct(tempnode))

[~,tempFeatIdx] = ismember(tempnode.bestFeatLabel,labels) ;

if(dataVec(tempFeatIdx)>tempnode.bestT)

tempnode = tempnode.greaterthan ;

else

tempnode = tempnode.lessthan ;

end

end

classVec(i) = tempnode ;

end

end

更新了一下代码,加入了boostrap采样,就是有放回的采样,我是这样采用的,有多少个样本就进行多少次有放回采样,然后这个过程进行50次求均值。用了之后,官方的库正确率道理44%,而我的还在30%。。差距一下突显,还需继续学习。。

补充一下那个sampleWithReplace函数

function [ sample_data ] = sampleWithReplace( dataset )

len = size(dataset,) ;

randidx = randsample(len,len,true) ;

sample_data = dataset(randidx,:) ;

end

连续值的CART(分类回归树)原理和实现

上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现 思考连续值和离散值的不同之处: 二分子树的时候不同:离散值需要求出最优的两个 ...

用cart(分类回归树)作为弱分类器实现adaboost

在之前的决策树到集成学习里我们说了决策树和集成学习的基本概念(用了adaboost昨晚集成学习的例子),其后我们分别学习了决策树分类原理和adaboost原理和实现, 上两篇我们学习了cart(决策分 ...

CART回归树

决策树算法原理(ID3,C4.5) 决策树算法原理(CART分类树) 决策树的剪枝 CART回归树模型表达式: 其中,数据空间被划分为R1~Rm单元,每个单元有一个固定的输出值Cm.这样可以计算模型输 ...

决策树算法原理&lpar;CART分类树&rpar;

决策树算法原理(ID3,C4.5) CART回归树 决策树的剪枝 在决策树算法原理(ID3,C4.5)中,提到C4.5的不足,比如模型是用较为复杂的熵来度量,使用了相对较为复杂的多叉树,只能处理分类不 ...

大白话5分钟带你走进人工智能-第二十六节决策树系列之Cart回归树及其参数&lpar;5&rpar;

第二十六节决策树系列之Cart回归树及其参数(5) 上一节我们讲了不同的决策树对应的计算纯度的计算方法, ...

机器学习实战---决策树CART回归树实现

机器学习实战---决策树CART简介及分类树实现 一:对比分类树 CART回归树和CART分类树的建立算法大部分是类似的,所以这里我们只讨论CART回归树和CART分类树的建立算法不同的地方.首先,我 ...

机器学习实战---决策树CART简介及分类树实现

https://blog.csdn.net/weixin_43383558/article/details/84303339?utm_medium=distribute.pc_relevant_t0. ...

sklearn 学习之分类树

概要 基于 sklearn 包自带的 iris 数据集,了解一下分类树的各种参数设置以及代表的意义.   iris 数据集介绍 iris 数据集包含 150 个样本,对应数据集的每行数据,每行数据包含 ...

sklearn CART决策树分类

sklearn CART决策树分类 决策树是一种常用的机器学习方法,可以用于分类和回归.同时,决策树的训练结果非常容易理解,而且对于数据预处理的要求也不是很高. 理论部分 比较经典的决策树是ID3.C ...

随机推荐

VueJs2&period;0建议学习路线

最近VueJs确实火了一把,自从Vue2.0发布后,Vue就成了前端领域的热门话题,github也突破了三万的star,那么对于新手来说,如何高效快速的学习Vue2.0呢. 既然大家会看这篇文章,那么 ...

HTML之&lt&semi;&excl;DOCTYPE&gt&semi; 标签介绍

实例:

文档的标题 & ...

suse linux编译安装GCC报错

gcc编译安装过程 1.先安装三个库 gmp mprc mpc 这三个库的源码要到官网去下载 1)安装gmp:首先建立源码同级目录 gmp-build,输入命令,第一次编译不通过,发现缺少一个叫m4的 ...

nginx 配置访问正则匹配

server{ listen 80; server_name api.zyy.com; root /var/www/api_zyy; index index.php; location ~ /asse ...

myeclipse搭建svn插件

在网上查了一下,安装的方法有几种,这里给大家推荐一种快速安装的方法. //第一步 : 下载 site-1.6.5.zip //===================================== ...

Codeforces&num;362

A题 题意:给定一串数列,t,t+s,t+s+1,t+2s,t+2s+1......问某一个数是否是数列当中的 题意:只需判断(x-t)与(x-t-1)能否整除s即可,注意起始时的判断 #includ ...

appium启动运行log分析

1.手动启动appium 服务 > Launching Appium server with command: C:\Program Files (x86)\Appium\node.exe ...

Python 集合 深浅copy

一,集合. 集合是无序的,不重复的数据集合,它里面的元素是可哈希的(不可变类型),但是集合本身是不可哈希(所以集合做不了字典的键)的.以下是集合最重要的两点: 去重,把一个列表变成集合,就自动去重了. ...

Codeforces 960F - Pathwalks

960F - Pathwalks 思路: ORZ 杜老师 用map写1e5个树状数组,骚操作 记Q为query和update次数,则节点个数约为Q*log(N) 代码: #include

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值