算法来自李航老师的《统计学习方法》
程序基于 MATLAB R2017a 编写,其中用到了字符串数组(string,双引号语法),更低版本的 MATLAB 无法运行;如需在低版本中运行,请将字符串改为字符数组。
clear; clc; close all;
%% 0. Load data
% Training set from Li Hang, "Statistical Learning Methods", p. 50.
% Each column is one sample. Rows 1..end-1 are features (row 1 takes the
% values 1/2/3, row 2 takes S/M/L); the last row is the class label (1/-1).
data = [ ...
    "1","1","1","1","1","2","2","2","2","2","3","3","3","3","3"; ...
    "S","M","M","S","S","S","M","M","L","L","L","M","M","L","L"; ...
    "-1","-1","1","1","-1","-1","-1","1","1","1","1","1","1","1","-1"];
% Class labels: the last row of the data.
Y = data(end, :);
% Features: one cell per feature row, each a 1-by-nSamples string array.
X = num2cell(data(1 : end-1, :), 2);
% Query point to classify: X(1) = "2", X(2) = "S" (a 2-by-1 string column).
X0 = ["2", "S"].';
%% 1. Compute probabilities
% Estimate the class priors P(Y=c) and the conditional probabilities
% P(X(i)=a|Y=c) from the training data (X, Y), storing each value in the
% map P under a readable key such as 'P(Y=1)' or 'P(X(2)=S|Y=-1)'.
%
% lambda is the additive smoothing constant:
%   P(Y=c)        = (#{Y=c} + lambda) / (N + K*lambda)
%   P(X(i)=a|Y=c) = (#{X(i)=a, Y=c} + lambda) / (#{Y=c} + S_i*lambda)
% where N is the number of samples, K the number of classes and S_i the
% number of distinct values of feature i. lambda = 0 gives the plain
% maximum-likelihood estimate; lambda = 1 gives Laplace smoothing.
lambda = 0;
P = containers.Map;
uY = unique(Y);
N = length(Y);
K = numel(uY);
% The set of values each feature takes does not depend on the class, so it
% is computed once here instead of inside the class loop.
uXs = cellfun(@unique, X, 'UniformOutput', false);
for y = uY
    isY = (Y == y);
    nY = nnz(isY);   % >= 1, since y is drawn from unique(Y)
    P(sprintf('P(Y=%s)',y)) = (nY + lambda) / (N + K*lambda);
    for i = 1 : length(X)
        uX = uXs{i};
        Si = numel(uX);
        for ux = uX
            P(sprintf('P(X(%d)=%s|Y=%s)',i,ux,y)) = ...
                (nnz(X{i} == ux & isY) + lambda) / (nY + Si*lambda);
        end
    end
end
% Print every estimated probability in P as a rational approximation,
% in the map's key order.
fprintf('*****************************************\n')
probKeys = P.keys;
for n = 1 : numel(probKeys)
    key = probKeys{n};
    fprintf('%s = %s\n', key, rats(P(key)));
end
%% 2. Compute classification result
% For each class y, score(y) = P(Y=y) * prod_i P(X(i)=X0(i)|Y=y), which is
% proportional to the posterior P(Y=y|X=X0) under the naive conditional
% independence assumption. Each score is stored in rP under the key
% 'P(X0=<class>)'.
if numel(X0) ~= length(X)
    error('NaiveBayes:featureCount', ...
        'X0 has %d features but the training data has %d.', ...
        numel(X0), length(X));
end
rP = containers.Map;
for y = uY
    score = P(sprintf('P(Y=%s)',y));
    for i = 1 : length(X0)
        key = sprintf('P(X(%d)=%s|Y=%s)',i,X0(i),y);
        % A query value never seen in training has no entry in P; report it
        % explicitly instead of failing with a generic missing-key error.
        if ~isKey(P, key)
            error('NaiveBayes:unseenValue', ...
                'Feature %d value "%s" does not occur in the training data.', ...
                i, char(X0(i)));
        end
        score = score * P(key);
    end
    rP(sprintf('P(X0=%s)',y)) = score;
end
% Print the per-class score stored in rP as a rational approximation,
% in the map's key order.
fprintf('*****************************************\n')
scoreKeys = rP.keys;
for n = 1 : numel(scoreKeys)
    key = scoreKeys{n};
    fprintf('%s = %s\n', key, rats(rP(key)));
end
%% 3. Select the class with the largest score
fprintf('*****************************************\n')
% keys and values of a containers.Map come back in the same (sorted) order,
% so position n in classScores belongs to classKeys{n}.
classKeys = rP.keys;
classScores = cell2mat(rP.values);
[~, best] = max(classScores);   % first maximum wins on ties
% Print the parenthesised part of the winning key, e.g. '(X0=-1)'.
fprintf('%s\n', regexp(classKeys{best}, '\(.*\)', 'match', 'once'));
计算结果如下: