机器学习技法作业三题目13-15

说明:

1. 此处是台大林轩田老师主页上的hw7,对应coursera上“机器学习技法”作业三;

2. 本文给出大作业(13-15题)的代码;

3. Matlab代码;

4. 非职业码农,代码质量不高,变量命名也不规范,凑合着看吧,不好意思;

5. 如有问题,欢迎指教,QQ:50834。


题目13-15,分为主程序和四个函数

主程序:

% Train a fully-grown CART classification tree on the hw7 data, print the
% tree, report Ein/Eout, and plot true labels (circles) vs. predictions
% (stars) for both the training and the test set.
clear all;
close all;

global dtree_node;   % tree table shared with hw7_CART_train

data_trn = csvread('hw7_train.dat');
data_tst = csvread('hw7_test.dat');

% Last column is the +/-1 label; the remaining columns are features.
xtrn = data_trn(:,1:end-1);
ytrn = data_trn(:,end);
[N,k] = size(xtrn);
xtst = data_tst(:,1:end-1);
ytst = data_tst(:,end);
[Ntst,k] = size(xtst);

clear data_trn
clear data_tst

% One row per node: [num, father, sign/class, dim, thresh, son+, son-].
% Columns 3-7 are filled in by hw7_CART_train.
dtree_node = [];
node0 = [1,0,0,0,0,0,0];
hw7_CART_train(xtrn,ytrn, node0);

fprintf('=====================================================================\n');

fprintf('Decision tree for classification\n');
% BUGFIX: use size(...,1) instead of length(...). length() returns the
% LARGEST dimension, so for a tree with fewer rows than the 7 columns it
% would index past the last node of the table.
for node = 1:size(dtree_node,1),
    fprintf('%3d  ', dtree_node(node,1));
    if dtree_node(node,4) == 0,
        % Leaf: column 3 holds the predicted class.
        fprintf('Leaf node with class = %2d\n', dtree_node(node,3));
    else
        fprintf('if x%d>=%6.4f then node %d else node %d\n', dtree_node(node,4), ...
            dtree_node(node,5), dtree_node(node,6), dtree_node(node,7));
    end;
end;
fprintf('---------------------------------------------------------------------\n');

% In-sample error.
ypred = hw7_CART_pred(dtree_node, xtrn);
fprintf(' Ein  = %6.2f%% \n', sum(ypred~=ytrn)/N*100);
fprintf('---------------------------------------------------------------------\n');
figure;
hold on;
idxp = ytrn>0;
idxm = ytrn<0;
plot(xtrn(idxp,1),xtrn(idxp,2),'bo');   % true +1: blue circles
plot(xtrn(idxm,1),xtrn(idxm,2),'ro');   % true -1: red circles
idxp = ypred>0;
idxm = ypred<0;
plot(xtrn(idxp,1),xtrn(idxp,2),'b*');   % predicted +1: blue stars
plot(xtrn(idxm,1),xtrn(idxm,2),'r*');   % predicted -1: red stars

% Out-of-sample error.
ypred = hw7_CART_pred(dtree_node, xtst);
fprintf(' Eout = %6.2f%% \n', sum(ypred~=ytst)/Ntst*100);
fprintf('---------------------------------------------------------------------\n');

figure;
hold on;
idxp = ytst>0;
idxm = ytst<0;
plot(xtst(idxp,1),xtst(idxp,2),'bo');
plot(xtst(idxm,1),xtst(idxm,2),'ro');
idxp = ypred>0;
idxm = ypred<0;
plot(xtst(idxp,1),xtst(idxp,2),'b*');
plot(xtst(idxm,1),xtst(idxm,2),'r*');

函数1:

function gini = hw7_gini(y)
% Gini impurity of the label vector y: 1 - sum over classes of p_c^2,
% where p_c is the fraction of samples carrying class label c.

n_total = length(y);
classes = unique(y);
n_classes = length(classes);

% Class probabilities, one entry per distinct label.
probs = zeros(n_classes,1);
for c = 1:n_classes,
    probs(c) = sum(y==classes(c))/n_total;
end;

gini = 1 - sum(probs.^2);
end

函数2:

function [s, dim, thresh] = hw7_deci_stump_impurity(x, y)
% Pick the decision stump (feature dim, threshold thresh, sign s) that
% minimizes the branch-size-weighted Gini impurity of the two branches.
%
% Inputs:
%   x - N-by-k feature matrix
%   y - N-by-1 vector of +/-1 labels
% Outputs:
%   s      - predicted class (+1/-1) of the "x(:,dim) >= thresh" branch
%   dim    - feature index used for the split
%   thresh - threshold (midpoint between adjacent sorted feature values)

[N,k]=size(x);

% Default: no improving split found; predict the majority sign of y.
bx = hw7_gini(y);
thresh = -Inf;
s = sign(sum(y));
if s == 0,
    % BUGFIX: on a tied vote sign() returns 0, which never matches the
    % +/-1 labels; break the tie toward +1 (same rule as inside the loop).
    s = 1;
end;
dim = 1;

for feat = 1:k,
    [xsort, idxsort] = sort(x(:,feat));
    for rec = 1:N-1,
        % BUGFIX: skip tied positions. When xsort(rec) == xsort(rec+1)
        % the midpoint threshold equals the tied value itself, so the
        % ">= thresh" rule cannot reproduce the 1:rec / rec+1:N split
        % whose impurity is being scored here.
        if xsort(rec) == xsort(rec+1),
            continue;
        end;
        N1 = rec;
        bx1 = hw7_gini(y(idxsort(1:rec)));
        N2 = N-rec;
        bx2 = hw7_gini(y(idxsort(rec+1:end)));
        bx_tmp = (N1*bx1+N2*bx2)/N;     % weighted impurity of the split
        if bx_tmp < bx,
            bx = bx_tmp;
            thresh = (xsort(rec)+xsort(rec+1))/2;
            % Sign = whichever branch (upper minus lower) holds the larger
            % fraction of +1 labels; ties broken toward +1.
            s = sign(sum(y(idxsort(rec+1:end))==1)/N2-sum(y(idxsort(1:rec))==1)/N1);
            if s==0
                s=1;
            end;
            dim = feat;
        end;
    end;
end;

end

函数3:

function hw7_CART_train(x,y,node0)
% Recursively grow a CART classification tree (fully grown, no pruning).
% Nodes are appended as rows of the global table dtree_node, one row per
% node: [num, father, sign/class, dim, thresh, son+, son-].
% A leaf stores its class in column 3 and leaves dim (column 4) at 0.
% node0 is this node's pre-filled row [num, father, 0,0,0,0,0]; its row
% index num must equal the next free row of dtree_node.

global dtree_node;
dtree_node = [dtree_node;node0];
fn_num = node0(1);   % row index of the node being grown

y_uniq = unique(y);
y_num = length(y_uniq);

if y_num ~= 1,
    % Impure node: find the best stump split and record it in this row.
    [s, dim, thresh] = hw7_deci_stump_impurity(x, y);
    dtree_node(fn_num,3)=s;
    dtree_node(fn_num,4)=dim;
    dtree_node(fn_num,5)=thresh;
    
    idxp = x(:,dim)>=thresh;
    idxm = x(:,dim)<thresh;
    
    % Grow the ">= thresh" subtree first; its root takes the next row.
    dtree_size = size(dtree_node);
    nodep = [dtree_size(1)+1,fn_num,0,0,0,0,0];
    dtree_node(fn_num,6)=dtree_size(1)+1;
    hw7_CART_train(x(idxp,:),y(idxp),nodep);
    
    % Then the "< thresh" subtree; size must be re-read because the
    % recursive call above appended an unknown number of rows.
    dtree_size = size(dtree_node);
    nodem = [dtree_size(1)+1,fn_num,0,0,0,0,0];
    dtree_node(fn_num,7)=dtree_size(1)+1;
    hw7_CART_train(x(idxm,:),y(idxm),nodem);
else
    % Pure node: make it a leaf holding the single remaining class.
    dtree_node(fn_num,3)=y_uniq;
end;

函数4:

function y = hw7_CART_pred(dtree, x)
% Predict a class label for each row of x by walking the decision tree.
% dtree rows: [num, father, sign/class, dim, thresh, son+, son-].
% Internal nodes have dim > 0; leaves have dim == 0 and keep the class
% label in column 3.

n_samples = size(x, 1);
y = zeros(n_samples, 1);

for row = 1:n_samples,
    node = 1;                          % start at the root
    while dtree(node,4) > 1e-10,       % descend until a leaf (dim == 0)
        split_dim = dtree(node,4);
        if x(row, split_dim) >= dtree(node,5)
            node = dtree(node,6);      % ">= threshold" child
        else
            node = dtree(node,7);      % "< threshold" child
        end;
    end;
    y(row) = dtree(node,3);            % leaf class
end;

运行结果:

=====================================================================
Decision tree for classification
  1  if x2>=0.6262 then node 2 else node 5
  2  if x1>=0.8782 then node 3 else node 4
  3  Leaf node with class =  1
  4  Leaf node with class = -1
  5  if x1>=0.2244 then node 6 else node 19
  6  if x1>=0.5415 then node 7 else node 12
  7  if x2>=0.2859 then node 8 else node 9
  8  Leaf node with class =  1
  9  if x2>=0.2660 then node 10 else node 11
 10  Leaf node with class = -1
 11  Leaf node with class =  1
 12  if x2>=0.3586 then node 13 else node 16
 13  if x1>=0.2608 then node 14 else node 15
 14  Leaf node with class = -1
 15  Leaf node with class =  1
 16  if x1>=0.5016 then node 17 else node 18
 17  Leaf node with class = -1
 18  Leaf node with class =  1
 19  if x2>=0.1152 then node 20 else node 21
 20  Leaf node with class = -1
 21  Leaf node with class =  1
---------------------------------------------------------------------
 Ein  =   0.00% 
---------------------------------------------------------------------
 Eout =  12.60% 
---------------------------------------------------------------------



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值