实验四、近邻法分类器
实验目的
本实验旨在让同学理解近邻法的原理,通过软件编程分段线性分类器的极端情况,理解 k-近邻法和剪辑近邻的设计过程,掌握影响 k-近邻法错误率的估算因素等。
实验条件
MATLAB 软件、 PC 机器
实验原理
最近邻法可以扩展成找测试样本的 k 个最近样本作决策依据的方法。其基本规则是,在所有 N 个样本中找到与测试样本的 k 个最近邻者,其中各类别所占个数表示成
k
i
,
i
=
1
,
.
.
.
,
c
k_i,i=1,...,c
ki,i=1,...,c,则决策规划是:
k
j
(
X
)
=
max
i
k
i
(
X
)
,
i
=
1
,
.
.
.
,
c
→
X
∈
w
j
{k_j}(X) = \mathop {\max }\limits_i {k_i}(X),i = 1,...,c \to X \in {w_j}
kj(X)=imaxki(X),i=1,...,c→X∈wj
k 近邻一般采用 k 为奇数,跟投票表决一样,避免因两种票数相等而难以决策。
剪辑近邻法的基本思想是从这样一个现象出发的,即当不同类别的样本在分布上有交迭部分的,分类的错误率主要来自处于交迭区中的样本。当我们得到一个作为识别用的参考样本集时,由于不同类别交迭区域中不同类别的样本彼此穿插,导致用近邻法分类出错。因此如果能将不同类别交界处的样本以适当方式筛选,可以实现既减少样本数又提高正确识别率的双重目的。为此可以利用现有样本集对其自身进行剪辑。下面以两类别问题为例说明这种方法的原理。
假设现有一个样本集 N,样本数量为 N。我们将此样本集分成两个互相独立的样本子集。一个被当作考试集 a N T a^{NT} aNT ,另一个作为参考集 a N R a^{NR} aNR ,数量分别为 N T N_T NT与 N R N_R NR , N T N_T NT + N R N_R NR =N。将 a N T a^{NT} aNT中的样本表示成 X i , i = 1 , . . . , N T X_i,i=1,...,N_T Xi,i=1,...,NT,而在 a N R a^{NR} aNR 中的样本表示为 Y i , j = 1 , . . . , N R Y_i,j=1,...,N_R Yi,j=1,...,NR。
将一个样本集分成两个相互独立的样本子集是指,分完以后的两个子集具有相同的分布例如将一个样本集分成两个相互独立的对等子集,则在每个特征空间的子区域,两个子集都有相同的比例,或说各类数量近似相等。 要注意指出的是每个子区域(从大空间到小空间)实际做时要用从总的集合中随机抽取的方式进行。
剪辑的过程是: 首先对 a N T a^{NT} aNT中每一个 X i X_i Xi在 a N R a^{NR} aNR 中找到其最近邻的样本 Y i ( X i ) Y_i(X_i) Yi(Xi),用 Y i ( X i ) Y_i(X_i) Yi(Xi)表示 Y i Y_i Yi是 X i X_i Xi的最近邻参考样本。如果 Y i Y_i Yi与 X i X_i Xi不属于同一类别,则将 X i X_i Xi从 a N T a^{NT} aNT 中删除,最后从 a N T a^{NT} aNT中得到一个经过剪辑的样本集,称为剪辑样本集 a N T E a^{NTE} aNTE 。 a N T E a^{NTE} aNTE可用来取代原样本集 a N a_N aN,作为参考样本集对待识别样本进行分类.
a 经过剪辑后,要作为新的训练样本集,则 a N R a_{NR} aNR是对其性能进行测试的样本,如发现 a N T a_{NT} aNT中的某个训练样本对分类不利,就要把它剪辑掉。
实际上剪辑样本的过程也可以用 k-近邻法进行,即对 a N T a_NT aNT 中的每个样本 X i X_i Xi,找到在 a N R a_{NR} aNR 中的 k 个近邻,用 k-近邻法判断 X i X_i Xi是否被错分类。从而决定其取舍,其它过程与前述方法完全一样。
剪辑近邻法也可用到多类别情况。剪辑过程也可不止一次。重复多次的称为重复剪辑近邻法。
示例代码
function [index_cluster,cluster] = kmeans_func(data,cluster_num)
%% 原理推导Kmeans聚类算法
[m,n]=size(data);
cluster=data(randperm(m,cluster_num),:);%从m个点中随机选择cluster_num个点作为初始聚类中心点
epoch_max=1000;%最大次数
therad_lim=0.001;%中心变化阈值
epoch_num=0;
while(epoch_num<epoch_max)
epoch_num=epoch_num+1;
% distance1存储每个点到各聚类中心的欧氏距离
for i=1:cluster_num
distance=(data-repmat(cluster(i,:),m,1)).^2;
distance1(:,i)=sqrt(sum(distance'));
end
[~,index_cluster]=min(distance1');%index_cluster取值范围1~cluster_num
% cluster_new存储新的聚类中心
for j=1:cluster_num
cluster_new(j,:)=mean(data(find(index_cluster==j),:));
end
%如果新的聚类中心和上一轮的聚类中心距离和大于therad_lim,更新聚类中心,否则算法结束
if (sqrt(sum((cluster_new-cluster).^2))>therad_lim)
cluster=cluster_new;
else
break;
end
end
end
clc;clear;close all;
data(:,1)=[90,35,52,83,64,24,49,92,99,45,19,38,1,71,56,97,63,...
32,3,34,33,55,75,84,53,15,88,66,41,51,39,78,67,65,25,40,77,...
13,69,29,14,54,87,47,44,58,8,68,81,31];
data(:,2)=[33,71,62,34,49,48,46,69,56,59,28,14,55,41,39,...
78,23,99,68,30,87,85,43,88,2,47,50,77,22,76,94,11,80,...
51,6,7,72,36,90,96,44,61,70,60,75,74,63,40,81,4];
figure(1)
scatter(data(:,1),data(:,2),'LineWidth',2)
title("原始数据散点图")
cluster_num=3;
[index_cluster,cluster] = kmeans_func(data,cluster_num);
%% 画出聚类效果
figure(2)
% subplot(2,1,1)
a=unique(index_cluster); %找出分类出的个数
C=cell(1,length(a));
for i=1:length(a)
C(1,i)={find(index_cluster==a(i))};
end
for j=1:cluster_num
data_get=data(C{1,j},:);
scatter(data_get(:,1),data_get(:,2),100,'filled','MarkerFaceAlpha',.6,'MarkerEdgeAlpha',.9);
hold on
end
%绘制聚类中心
plot(cluster(:,1),cluster(:,2),'ks','LineWidth',2);
hold on
sc_t=mean(silhouette(data,index_cluster'));
title_str=['原理推导K均值聚类',' 聚类数为:',num2str(cluster_num),' SC轮廓系数:',num2str(sc_t)];
title(title_str)
% data set;
Sigma = [1, 0; 0, 1];
mu1 = [1, -1];
x1 = mvnrnd(mu1, Sigma, 200);
mu2 = [5, -4];
x2 = mvnrnd(mu2, Sigma, 200);
mu3 = [1, 4];
x3 = mvnrnd(mu3, Sigma, 200);
mu4 = [6, 4];
x4 = mvnrnd(mu4, Sigma, 200);
mu5 = [7, 0.0];
x5 = mvnrnd(mu5, Sigma, 200);
X = [x1; x2; x3; x4; x5];
X_label = [ones(200, 1); 2 * ones(200, 1); 3 * ones(200,1); 4 * ones(200, 1);5 * ones(200, 1)];
% Show the data points
plot(x1(:,1), x1(:,2), 'r.'); hold on;
plot(x2(:,1), x2(:,2), 'b.');
plot(x3(:,1), x3(:,2), 'k.');
plot(x4(:,1), x4(:,2), 'g.');
plot(x5(:,1), x5(:,2), 'm.');
% select initial clustering center
m = 30;
a = max(X);
b = min(X);
k=5;
mu = zeros(k,2*m);
r = zeros(m,1);
for t=1:m
for i=1:k
mu(i,2*t-1:2*t)=[a(1)+(b(1)-a(1))*rand,a(2)+(b(2)-a(2))*rand];
end
for j = 1 : 1000
R = repmat(X(j, :), k, 1) - mu(:,2*t-1:2*t);
r(t) = r(t) + sum(sum(R.*R));
end
end
p = find(r==min(r));
mu = mu(:,2*p-1:2*p);
label = zeros(1000, 1);
mu_new = mu;
eps = 1e-6;
delta = 1;
while (delta > eps)
mu = mu_new;
for i =1:1000
y = repmat (X(i, :), k, 1);
dist = y - mu;
d = sum(dist.*dist,2);
j = find(d==min(d));
label(i) = j;
end
for j = 1 : k
order = find(label == j);
mu_new(j, :) = mean(X(order, :), 1);
end
delta = sqrt(sum(sum((mu-mu_new).*(mu-mu_new))));
end
label = zeros(1000, 1);
for i = 1 : 1000
R = repmat(X(i,:),k,1) - mu;
Residual = sum(R.*R,2);
j = find(Residual == min(Residual));
label(i) = j;
end
% Construct map function
s = zeros(k, 1);
for j =1 : k
order = find(label==j);
Y = X_label(order);
s(j) = mode(Y);
end
map_label =zeros(1000, 1);
for j = 1 : k
map_label(label==j) = s(j);
end
figure;
hold on;
for i =1:1000
if map_label(i)==1
plot(X(i,1),X(i,2),'r.');
elseif map_label(i)==2
plot(X(i,1),X(i,2),'b.');
elseif map_label(i)==3
plot(X(i,1),X(i,2),'k.');
elseif map_label(i)==4
plot(X(i,1),X(i,2),'g.');
else
plot(X(i,1),X(i,2),'m.');
end
end
% show the cluster center
for i = 1 : 5
plot(mu(i,1),mu(i,2),'yo','LineWidth',3);
end
% Calculate NMI(Normalized Mutual Information)
d = zeros(5, 1);
g = d;
sigma = zeros(5,5);
numerator = 0;
denominator1 = 0;
denominator2 = 0;
for i = 1 : 5
d(i) = length(find(map_label==i));
g(i) = length(find(X_label==i));
end
for i = 1 : 5
for j = 1 : 5
order = find(map_label==i);
sigma(i,j) = length(find(X_label(order)==j));
if sigma(i,j)~=0
numerator = numerator + sigma(i,j).*log(1000.*sigma(i,j)./(d(i).*g(j)));
end
end
end
for i = 1 : 5
if d(i)~=0
denominator1 = denominator1 + d(i).*log(d(i)/1000);
end
if g(i)~=0
denominator2 = denominator2 + g(i).*log(g(i)/1000);
end
end
denominator = sqrt(denominator1 * denominator2);
NMI = numerator/denominator;
fprintf('NMI=%.3f\n',NMI);
accuracy = sum(map_label == X_label)/1000;
fprintf('accuracy=%.3f\n',accuracy);