说明:
1. 此处是台大林轩田老师主页上的hw7,对应coursera上“机器学习技法”作业三;
2. 本文给出大作业(13-15题)的代码;
3. Matlab代码;
4. 非职业码农,代码质量不高,变量命名也不规范,凑合着看吧,不好意思;
5. 如有问题,欢迎指教,QQ:50834。
题目13-15的代码分为一个主程序和四个函数。
主程序:
% Driver script: trains a CART decision tree on hw7_train.dat, prints the
% tree, reports the training and test error rates, and plots true versus
% predicted labels for both data sets.
clear all;
close all;
global dtree_node;

% Load the data; the last column is the label, all other columns are features.
data_trn = csvread('hw7_train.dat');
data_tst = csvread('hw7_test.dat');
xtrn = data_trn(:,1:end-1);
ytrn = data_trn(:,end);
[N,k] = size(xtrn);
xtst = data_tst(:,1:end-1);
ytst = data_tst(:,end);
[Ntst,k] = size(xtst);
clear data_trn
clear data_tst

% dtree_node is the global node table filled in by hw7_CART_train.
% Each row is: num, father, sign, dim, th, son +, son -
% For the root only num is known up front; sign, dim, th, son+ and son-
% are placeholders that the training routine overwrites.
dtree_node = [];
node0 = [1,0,0,0,0,0,0];
hw7_CART_train(xtrn,ytrn, node0);

fprintf('=====================================================================\n');
% Optional raw dump of the node table (disabled):
% fprintf(' Node# | Fth.Node | Sign | Deci.Dim | Thrshhld | SonNode+ | SonNode-\n');
% for node = 1:size(dtree_node,1),
% fprintf('%4d %9d %8d %8d %13.4f %7d %10d\n', dtree_node(node,:));
% end;
% fprintf('---------------------------------------------------------------------\n');
fprintf('Decision tree for classification\n');
% Iterate over rows explicitly: length() would return the largest dimension
% (7 columns) and index out of range for trees with fewer than 7 nodes.
for node = 1:size(dtree_node,1)
    fprintf('%3d ', dtree_node(node,1));
    if dtree_node(node,4) == 0
        % dim == 0 marks a leaf; column 3 then holds the predicted class.
        fprintf('Leaf node with class = %2d\n', dtree_node(node,3));
    else
        fprintf('if x%d>=%6.4f then node %d else node %d\n', dtree_node(node,4), ...
            dtree_node(node,5), dtree_node(node,6), dtree_node(node,7));
    end
end
fprintf('---------------------------------------------------------------------\n');

% Training error (percentage of misclassified training samples).
ypred = hw7_CART_pred(dtree_node, xtrn);
fprintf(' Ein = %6.2f%% \n', sum(ypred~=ytrn)/N*100);
fprintf('---------------------------------------------------------------------\n');

% Training set plot over the first two feature columns:
% circles = true labels, stars = predictions; blue = positive, red = negative.
figure;
hold on;
idxp = ytrn>0;
idxm = ytrn<0;
plot(xtrn(idxp,1),xtrn(idxp,2),'bo');
plot(xtrn(idxm,1),xtrn(idxm,2),'ro');
idxp = ypred>0;
idxm = ypred<0;
plot(xtrn(idxp,1),xtrn(idxp,2),'b*');
plot(xtrn(idxm,1),xtrn(idxm,2),'r*');

% Test error (percentage of misclassified test samples).
ypred = hw7_CART_pred(dtree_node, xtst);
fprintf(' Eout = %6.2f%% \n', sum(ypred~=ytst)/Ntst*100);
fprintf('---------------------------------------------------------------------\n');

% Test set plot, same marker convention as above.
figure;
hold on;
idxp = ytst>0;
idxm = ytst<0;
plot(xtst(idxp,1),xtst(idxp,2),'bo');
plot(xtst(idxm,1),xtst(idxm,2),'ro');
idxp = ypred>0;
idxm = ypred<0;
plot(xtst(idxp,1),xtst(idxp,2),'b*');
plot(xtst(idxm,1),xtst(idxm,2),'r*');
函数1:
function gini = hw7_gini(y)
% HW7_GINI  Gini impurity of a label vector.
%   gini = 1 - sum over distinct labels of (fraction of that label)^2.
%   An empty y yields 1 because no term is subtracted.
total = length(y);
labels = unique(y);
gini = 1;
pos = 1;
while pos <= length(labels)
    % Fraction of samples carrying the current distinct label.
    frac = sum(y == labels(pos)) / total;
    gini = gini - frac^2;
    pos = pos + 1;
end
end
函数2:
function [s, dim, thresh] = hw7_deci_stump_impurity(x, y)
% HW7_DECI_STUMP_IMPURITY  Best single-feature threshold split by Gini impurity.
%
%   Inputs:
%     x - N-by-k feature matrix, one sample per row.
%     y - N-by-1 label vector.
%
%   Outputs:
%     s      - sign of (fraction of labels equal to 1 on the upper side minus
%              the same fraction on the lower side), with 0 mapped to 1.
%              When no valid split exists, s = sign(sum(y)).
%     dim    - column of x used by the split (1 when no valid split exists).
%     thresh - split threshold: samples with x(:,dim) >= thresh form the
%              upper side. -Inf when no valid split exists, i.e. every
%              feature column is constant.
%
%   The split minimising the size-weighted Gini impurity of the two sides is
%   returned; among equal candidates the first one encountered wins.
[N,k] = size(x);
% Start from Inf rather than the parent impurity so that the best real split
% is returned even when it does not strictly lower the impurity (e.g.
% XOR-like data); otherwise the caller would get the degenerate -Inf split.
bx = Inf;
thresh = -Inf;
s = sign(sum(y));
dim = 1;
for feat = 1:k
    [xsort, idxsort] = sort(x(:,feat));
    for rec = 1:N-1
        lo = xsort(rec);
        hi = xsort(rec+1);
        % A cut between two equal values cannot be realised by the rule
        % x >= thresh, so its impurity would not describe the real partition.
        % The negated comparison also skips NaN entries.
        if ~(lo < hi)
            continue;
        end
        N1 = rec;
        bx1 = hw7_gini(y(idxsort(1:rec)));
        N2 = N-rec;
        bx2 = hw7_gini(y(idxsort(rec+1:end)));
        bx_tmp = (N1*bx1+N2*bx2)/N;
        if bx_tmp < bx
            bx = bx_tmp;
            thresh = (lo+hi)/2;
            % If rounding collapses the midpoint onto the lower value, use
            % the upper value so that lo still falls on the lower side.
            if thresh <= lo
                thresh = hi;
            end
            s = sign(sum(y(idxsort(rec+1:end))==1)/N2-sum(y(idxsort(1:rec))==1)/N1);
            if s==0
                s=1;
            end
            dim = feat;
        end
    end
end
end
函数3:
function hw7_CART_train(x,y,node0)
% HW7_CART_TRAIN  Recursively grow a decision tree into the global dtree_node.
%
%   Inputs:
%     x     - N-by-k feature matrix, one sample per row.
%     y     - N-by-1 label vector.
%     node0 - 1-by-7 row [num, father, sign, dim, th, son+, son-] for the node
%             to create. node0(1) is used as the row index into dtree_node,
%             so it must equal the current number of rows plus one.
%
%   The node is appended to dtree_node and then completed in place:
%     * leaf: column 3 holds the class label and column 4 keeps the value
%       passed in node0 (callers pass 0, which marks a leaf);
%     * internal node: columns 3-5 hold the stump's s, dim and thresh, and
%       columns 6/7 hold the row numbers of the ">=" and "<" children.
global dtree_node;
dtree_node = [dtree_node;node0];
fn_num = node0(1);
y_uniq = unique(y);

% No samples at all: keep the placeholder row as a leaf instead of recursing.
if isempty(y_uniq)
    return;
end

% Pure node: every sample has the same label.
if length(y_uniq) == 1
    dtree_node(fn_num,3)=y_uniq;
    return;
end

[s, dim, thresh] = hw7_deci_stump_impurity(x, y);
idxp = x(:,dim)>=thresh;
% Complement of idxp, so a sample failing ">=" (including NaN) goes to the
% "-" child, matching the else-branch taken at prediction time.
idxm = ~idxp;

% If the split leaves one side empty (e.g. identical feature rows carrying
% different labels) recursing would repeat the same problem forever.
% Stop with a majority-label leaf; mode() returns the smallest label on ties.
if ~any(idxp) || ~any(idxm)
    dtree_node(fn_num,3)=mode(y);
    return;
end

dtree_node(fn_num,3)=s;
dtree_node(fn_num,4)=dim;
dtree_node(fn_num,5)=thresh;

% The "+" child takes the next free row; its whole subtree is built before
% the "-" child is numbered, so the table size is re-read in between.
dtree_size = size(dtree_node);
nodep = [dtree_size(1)+1,fn_num,0,0,0,0,0];
dtree_node(fn_num,6)=dtree_size(1)+1;
hw7_CART_train(x(idxp,:),y(idxp),nodep);

dtree_size = size(dtree_node);
nodem = [dtree_size(1)+1,fn_num,0,0,0,0,0];
dtree_node(fn_num,7)=dtree_size(1)+1;
hw7_CART_train(x(idxm,:),y(idxm),nodem);
函数4:
function y = hw7_CART_pred(dtree, x)
% HW7_CART_PRED  Classify each row of x by walking the decision tree table.
%
%   Inputs:
%     dtree - node table with rows [num, father, sign, dim, th, son+, son-].
%     x     - feature matrix, one sample per row.
%   Output:
%     y     - column vector holding column 3 of the node reached by each row.
%
%   Every walk starts at row 1 and stops at the first node whose dim
%   (column 4) is not greater than 1e-10.
num_samples = size(x,1);
y = zeros(num_samples,1);
for row = 1:num_samples
    cur = 1;
    while true
        split_dim = dtree(cur,4);
        if ~(split_dim > 1e-10)
            break;
        end
        % Follow son- unless the feature value is >= the threshold.
        if ~(x(row,split_dim) >= dtree(cur,5))
            cur = dtree(cur,7);
        else
            cur = dtree(cur,6);
        end
    end
    y(row) = dtree(cur,3);
end
运行结果:
=====================================================================
Decision tree for classification
1 if x2>=0.6262 then node 2 else node 5
2 if x1>=0.8782 then node 3 else node 4
3 Leaf node with class = 1
4 Leaf node with class = -1
5 if x1>=0.2244 then node 6 else node 19
6 if x1>=0.5415 then node 7 else node 12
7 if x2>=0.2859 then node 8 else node 9
8 Leaf node with class = 1
9 if x2>=0.2660 then node 10 else node 11
10 Leaf node with class = -1
11 Leaf node with class = 1
12 if x2>=0.3586 then node 13 else node 16
13 if x1>=0.2608 then node 14 else node 15
14 Leaf node with class = -1
15 Leaf node with class = 1
16 if x1>=0.5016 then node 17 else node 18
17 Leaf node with class = -1
18 Leaf node with class = 1
19 if x2>=0.1152 then node 20 else node 21
20 Leaf node with class = -1
21 Leaf node with class = 1
---------------------------------------------------------------------
Ein = 0.00%
---------------------------------------------------------------------
Eout = 12.60%
---------------------------------------------------------------------