CART Algorithm Implementation: Building the Tree

Code download:  http://pan.baidu.com/s/1kTr9eAj


Below are some of my thoughts on CART; the pruning part will follow in a later post. Discussion is welcome.


Classification and regression trees: the CART (Classification And Regression Tree) algorithm uses binary recursive partitioning. It splits the current sample set into two subsets, so that every internal node it generates has exactly two branches; the decision tree CART produces is therefore a structurally simple binary tree.

CART uses a learning sample set of the form L = {X_1, X_2, ..., X_m, Y},
where X_1, ..., X_m are called the attribute vectors and may be either continuous or discrete, and Y is called the label vector; its values may be ordered or numerical, or they may be categorical. When Y is an ordered numerical quantity the tree is called a regression tree; when Y is categorical it is called a classification tree.
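To make the layout concrete: the MATLAB code below stores the learning set as an Ni-by-M `patterns` matrix (one row per attribute, one column per sample) together with a length-M `targets` vector. A minimal Python sketch of that layout (toy data; the values are purely illustrative):

```python
# Toy learning set in the layout the code below expects:
# one row per attribute (Ni rows), one column per sample (M columns),
# plus a discrete class label per sample (the classification-tree case).
patterns = [
    [0, 0, 1, 1, 2, 2],   # attribute X_1 of the 6 samples
    [0, 1, 0, 1, 0, 1],   # attribute X_2 of the 6 samples
]
targets = [0, 0, 0, 1, 1, 1]  # label Y of each sample

Ni = len(patterns)     # number of attributes
M  = len(patterns[0])  # number of samples
```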

A classification tree rests on two basic ideas: first, build the tree by recursively partitioning the input-variable space over the training samples; second, prune the tree using validation data.



Generating the decision tree
Tree building: in a classification and regression tree, the class set Targets is the dependent variable and the chosen attribute set Attributelist supplies the independent variables; the algorithm uses Attributelist recursively to divide the p-dimensional space into non-overlapping rectangles.


function tree = make_tree(patterns, targets, Dlength, split_type, inc_node)
%Build a tree recursively
%Dlength is the total number of training samples

% If only one class remains, the node is a leaf
if (length(unique(targets)) == 1),
   %There is only one type of targets, and this generates a warning, so deal with it separately
   tree.right      = [];
   tree.left       = [];
   tree.Raction    = targets(1);
   tree.Laction    = targets(1);
   tree.Resub      = 0;
   tree.label      = targets(1);
   return
end
% Ni always equals the number of attributes
[Ni, M]         = size(patterns);
Nt              = unique(targets);   % distinct class labels
N               = hist(targets, Nt); % number of samples in each class
tree.Resub      = ( M - max(N) )/ Dlength;
tree.label      = Nt(find(N == max(N), 1)); % majority class (first one on ties)
% Stopping condition: all samples in T belong (up to the inc_node
% threshold) to a single class, or only one sample is left;
% the node is then labelled with its majority class.
% tree's last level     | tree.Resub < 1e-3
if ((sum(N < Dlength*inc_node) == length(Nt) - 1) | (M == 1)), 
   %No further splitting is neccessary
   tree.right      = [];
   tree.left       = [];
   
   if (length(Nt) ~= 1),
      % use the class with the most instances as this node's label
      MLlabel      = find(N == max(N), 1);
   else
      MLlabel      = 1;
   end
   tree.Raction    = Nt(MLlabel);
   tree.Laction    = Nt(MLlabel);
else
   %Split the node according to the splitting criterion
   I            = zeros(1,Ni);   % weighted Gini of the best split per attribute
   split_point  = zeros(1,Ni);
   
   for i = 1:Ni,
      % findminGini returns, for attribute i, the attribute value whose
      % binary split has the smallest weighted Gini index
      % ('feval' invokes the named function)
      % split_point(i) = fminbnd('CARTfunctions', min(patterns(i,:)), max(patterns(i,:)), op, patterns, targets, i, split_type);
      split_point(i) = findminGini( patterns, targets, i );
      I(i)           = feval('CARTfunctions', split_point(i), patterns, targets, i);
   end
   
   % choose the attribute and split point with the smallest weighted Gini
   [m, dim]       = min(I);
   loc            = split_point(dim);
   %So, the split is on dimension 'dim' at location 'loc'
   indices        = 1:M;       %M is the number of samples
   tree.Raction   = ['patterns(' num2str(dim) ',indices) ~= ' num2str(loc)];
   tree.Laction   = ['patterns(' num2str(dim) ',indices) == ' num2str(loc)];
   in_right       = find(eval(tree.Raction));
   in_left        = find(eval(tree.Laction));
   
   if isempty(in_right) || isempty(in_left)
      %No possible split found
      tree.right      = [];
      tree.left       = [];
      if (length(Nt) ~= 1),
         MLlabel      = find(N == max(N), 1);
      else
         MLlabel      = 1;
      end
      tree.Raction    = Nt(MLlabel);
      tree.Laction    = Nt(MLlabel);
   else
      %...It's possible to build new nodes
      % recurse on the right and left sample subsets
      tree.right = make_tree(patterns(:,in_right), targets(in_right), Dlength, split_type, inc_node);
      tree.left  = make_tree(patterns(:,in_left),  targets(in_left),  Dlength, split_type, inc_node);
    
   end
   
end
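Note the design choice above: each split is stored as a MATLAB expression string such as 'patterns(1,indices) == 0.5' and applied with eval, which keeps the split replayable on new data but is fragile. The same partition step can be sketched in Python without eval (a minimal sketch; the function name and data are illustrative, and `patterns` is stored one attribute per row as in the MATLAB code):

```python
def partition(patterns, targets, dim, value):
    """Split sample indices the way make_tree does: samples whose attribute
    `dim` equals `value` go left, all remaining samples go right."""
    in_left  = [j for j in range(len(targets)) if patterns[dim][j] == value]
    in_right = [j for j in range(len(targets)) if patterns[dim][j] != value]
    return in_left, in_right

patterns = [[0, 0, 1, 1],
            [5, 6, 5, 6]]
targets  = [0, 0, 1, 1]
left, right = partition(patterns, targets, 0, 0)  # split on attribute 0 == 0
```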

function index = findminGini( patterns, targets, dim)
%Find the value of attribute 'dim' that yields the minimum weighted
%Gini index when used as a binary split point
Nt      = unique(patterns(dim,:)); 

% try every distinct value of the attribute as a candidate split point
Gini = zeros(1, length(Nt));
for i = 1:length(Nt)
    Gini(i) = CARTfunctions(Nt(i), patterns, targets, dim);
end
% return the value achieving the minimum Gini index (first one on ties)
D     = find(Gini == min(Gini));
index = Nt(D(1));
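The scan over candidate values can be sketched in Python (a minimal sketch, assuming discrete attributes stored one per row; the inner `weighted_gini` plays the role of CARTfunctions):

```python
def find_min_gini(patterns, targets, dim):
    """Return the value of attribute `dim` whose equality split gives the
    smallest weighted Gini index; ties go to the first (smallest) value,
    matching Nt(D(1)) in the MATLAB code."""
    def gini(labels):
        # Gini index 1 - sum(p_i^2) of a list of class labels
        n = len(labels)
        if n == 0:
            return 0.0
        return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

    def weighted_gini(value):
        # impurity of the binary split "attribute dim == value"
        left  = [t for x, t in zip(patterns[dim], targets) if x == value]
        right = [t for x, t in zip(patterns[dim], targets) if x != value]
        p = len(left) / len(targets)
        return p * gini(left) + (1 - p) * gini(right)

    candidates = sorted(set(patterns[dim]))  # like unique() in MATLAB
    return min(candidates, key=weighted_gini)
```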

function delta = CARTfunctions(split_point, patterns, targets, dim)

%Compute the weighted Gini index of the binary split
%"attribute dim == split_point" (lower is better)
Uc  = unique(targets);
% left subset D1 (== split_point) and right subset D2 (~= split_point)
in  = find(patterns(dim,:) == split_point);
out = find(patterns(dim,:) ~= split_point);
Pl  = zeros(1, length(Uc));
Pr  = zeros(1, length(Uc));
for i = 1:length(Uc),
   if ~isempty(in),
       Pl(i) = length(find(targets(in)  == Uc(i)))/length(in);
   end
   if ~isempty(out),
       Pr(i) = length(find(targets(out) == Uc(i)))/length(out);
   end
end

%Gini=1-sum(pi(n)^2)
Er		= 1 - sum(Pr.^2);
El		= 1 - sum(Pl.^2); 

P		= length(find(patterns(dim, :) == split_point)) / length(targets);
delta   = P*El + (1-P)*Er;
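CARTfunctions thus returns the weighted Gini index delta = P*El + (1-P)*Er of the two children, and the best split is the one that minimizes it. The same computation can be sketched in Python (a minimal sketch; the function names are illustrative and the attribute is assumed discrete, stored as one row of `patterns`):

```python
from collections import Counter

def gini(labels):
    """Gini index 1 - sum(p_i^2) over the class frequencies of `labels`."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def weighted_gini(patterns, targets, dim, split_point):
    """Weighted Gini of the binary split "attribute dim == split_point":
    P*Gini(left) + (1-P)*Gini(right), as CARTfunctions computes."""
    left  = [t for x, t in zip(patterns[dim], targets) if x == split_point]
    right = [t for x, t in zip(patterns[dim], targets) if x != split_point]
    P = len(left) / len(targets)
    return P * gini(left) + (1 - P) * gini(right)
```

For a pure split (each side holds a single class) the weighted Gini is 0; a split that leaves two classes evenly mixed on both sides scores 0.5.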




