多类线性分类器算法原理及代码实现 MATLAB

多类线性分类器算法原理及代码实现 MATLAB

一、算法原理

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
下面举例说明为何蓝圈部分在case2中是确定的而在case1中不确定:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、代码实现

1、HK函数

function [] = HK(w1_data,w2_data)
%w1_data为第一类数据集  w2_data为第二类数据集
%此函数的作用为用HK算法对输入的数据集w1_data,w2_data做二分类,并画出分界面
 
lr=0.5 ;%学习率
MaxIter = ceil(20000/lr);%最大迭代次数
Eps = 1e-5; %精度

%% 增广化
w1_data=[w1_data';ones(1,size(w1_data,1))]';
w2_data=[w2_data';ones(1,size(w2_data,1))]';

Y=[w1_data;-w2_data];%待分类数据,w1在决策平面的正侧
[xmin,~]=min(Y,[],1); %求出横坐标最小值
[xmax,~]=max(Y,[],1); %求出横坐标最大值
b=rand(size(Y,1),1);
Y_flag=(Y'*Y)\Y';
% Y_flag=pinv(Y);

N=length(b);
C=0;%迭代次数
while(C < MaxIter)
    a=Y_flag*b;
    e=Y*a-b;
    zeronum = sum(e<Eps & e>-Eps);
    nenum = sum(e<0);
    if  zeronum==N %all is 0
        break;
    elseif nenum ==N
        break;
    end
    delta=lr*(e+abs(e));
    b = b + delta; %更新b
    C=C+1; %迭代次数+1
end
if C ==MaxIter
    if sum(e>-Eps)==N % all is larger than or equal to 0.
        fprintf('It has cost all iterartions(%d), and all elements are larger than or equal to 0. The sample is linear to be classified!\n',MaxIter);
    else
        if sum(e<=Eps) ==N% all is less than or equal to 0.
            fprintf('It has cost all iterartions(%d), and all elements are less than or equal to 0. The sample is non-linear to be classified!\n',MaxIter);
        else
            fprintf('It has cost all iterartions(%d), the sample is uncertain to be classified!\n',MaxIter);
        end
    end
end


%% 画出分界面
x1=(xmin-8:0.1:xmax+20);
if abs(a(2))<1e-7
    x1=-a(3)/a(1);
    x2=(-1:0.1:1);
    x1=ones(size(x2))*x1;
else
    x2=(a(1)*x1+a(3))/(-a(2));
end
plot(x1,x2,'LineWidth',1);
hold on;


2、case1

clc;
close all;
clear;
%% 生成数据
rng(2020);   %指定一个种子
mu1 = [0 3];
sigma1 = [0.5 0; 
         0 0.5];
data1 = mvnrnd(mu1,sigma1,300); %生成一个300*2的矩阵,每一列的数据分别以03为均值,标准差都为0.5

rng(2021);  %指定一个种子
mu2 = [6 7];
sigma2 = [0.5 0; 
         0 0.5];
data2 = mvnrnd(mu2,sigma2,300); %生成一个300*2的矩阵,每一列的数据分别以67为均值,标准差都为0.5

rng(2022);  %指定一个种子
mu3 = [5 -5];
sigma3 = [0.5 0; 
         0 0.5];
data3 = mvnrnd(mu3,sigma3,300); %生成一个300*2的矩阵,每一列的数据分别以5-5为均值,标准差都为0.5

HK(data1,[data3;data2]); %data1为一类,其他所有数据为另一类
HK(data2,[data1;data3]); %data2为一类,其他所有数据为另一类
HK(data3,[data1;data2]); %data3为一类,其他所有数据为另一类

%% 画出点集
plot(data1(:,1),data1(:,2),'r+');hold on;
plot(data2(:,1),data2(:,2),'b*');hold on;
plot(data3(:,1),data3(:,2),'m^');hold on;

实验结果:

在这里插入图片描述

3、case2

clc;
close all;
clear;
%% 生成数据
rng(2020);   %指定一个种子
mu1 = [0 3];
sigma1 = [0.5 0; 
         0 0.5];
data1 = mvnrnd(mu1,sigma1,300); %生成一个300*2的矩阵,每一列的数据分别以03为均值,标准差都为0.5

rng(2021);  %指定一个种子
mu2 = [6 7];
sigma2 = [0.5 0; 
         0 0.5];
data2 = mvnrnd(mu2,sigma2,300); %生成一个300*2的矩阵,每一列的数据分别以67为均值,标准差都为0.5

rng(2022);  %指定一个种子
mu3 = [5 -5];
sigma3 = [0.5 0; 
         0 0.5];
data3 = mvnrnd(mu3,sigma3,300); %生成一个300*2的矩阵,每一列的数据分别以5-5为均值,标准差都为0.5

HK(data1,data2); %对data1,data2作二分类
HK(data2,data3); %对data2,data3作二分类
HK(data3,data1); %对data1,data3作二分类

%% 画出点集
plot(data1(:,1),data1(:,2),'r+');hold on;
plot(data2(:,1),data2(:,2),'b*');hold on;
plot(data3(:,1),data3(:,2),'m^');hold on;

实验结果:

在这里插入图片描述

4、case3

clc;
close all;
clear;


%% 生成数据
rng(1800);   %指定一个种子
mu1 = [0 3];
sigma1 = [0.5 0; 
         0 0.5];
data1 = mvnrnd(mu1,sigma1,300); %生成一个300*2的矩阵,每一列的数据分别以03为均值,标准差都为0.5

rng(1900);  %指定一个种子
mu2 = [6 7];
sigma2 = [0.5 0; 
         0 0.5];
data2 = mvnrnd(mu2,sigma2,300); %生成一个300*2的矩阵,每一列的数据分别以67为均值,标准差都为0.5

rng(2022);  %指定一个种子
mu3 = [5 -5];
sigma3 = [0.5 0; 
         0 0.5];
data3 = mvnrnd(mu3,sigma3,300); %生成一个300*2的矩阵,每一列的数据分别以5-5为均值,标准差都为0.5

%% 
Label1=ones(length(data1),1);   %为data1的每个数据生成一个标签
Label2=ones(length(data2),1)+1; %为data2的每个数据生成一个标签
Label3=ones(length(data3),1)+2; %为data3的每个数据生成一个标签
Data=[data1;data2;data3]; %将三个数据集整合起来
Label=[Label1;Label2;Label3]; %将三个标签集整合起来
[xmin,ymin]=min(Data,[],1); %提取出数据集横坐标最小值,纵坐标最小值
[xmax,ymax]=max(Data,[],1); %提取出数据集横坐标最大值,纵坐标最大值
Data=[Data,ones(size(Data,1),1)]; %为每个数据增加一维,增加的一维取值为1

%%
[N,M]=size(Data);
A=randn(M,3); %随机初始化三类的三个权向量 每一列是一个权向量
p=1; %学习率
t=0; %迭代器
MaxInt=1000; %最大迭代次数
while(t<MaxInt)
    C=0; %分类正确计算器
    for i=1:N
        y=Data(i,:)'; %提取出第i个数据
        tmp=A'*y; %计算数据在三个判别器中的值
        [v,ind]=max(tmp); %提取出最大判别器的值和序号
        if ind==Label(i) % 如果最大判别器就是该类的判别器
            C=C+1; %该数据正确分类
        else
            A(:,ind)=A(:,ind)-p*y;
            A(:,Label(i))=A(:,Label(i))+p*y; %套用公式,更新相应的权向量
        end
    end
    t=t+1; %迭代次数+1
    if C==N %如果样本全部正确分类,则退出循环
        break;
    end
end

%% 求交点
A_=A(1:2,:)';
b_=-A(3,:)';
pt=(A_'*A_)\A_'*b_; %求一个向量,令三个判别器的值全为零,该向量就是交点
                    %注意A的每一列是增广的权向量,真正的权是前两行,第三行其实是w0
                    
                    
%%
w1=A(:,1)-A(:,2);
w2=A(:,1)-A(:,3);
w3=A(:,2)-A(:,3);
X1=xmin-30:0.1:pt(1);
X2=pt(1):0.1:xmax+30;
Y1=(-w1(1)*X2-w1(3))/(w1(2));
Y2=(-w2(1)*X1-w2(3))/(w2(2));
Y3=(-w3(1)*X2-w3(3))/(w3(2));

%% 画出三个数据集的点
plot(data1(:,1),data1(:,2),'r+');hold on;
plot(data2(:,1),data2(:,2),'b*');hold on;
plot(data3(:,1),data3(:,2),'m^');hold on;

%% 画出三个分类平面
plot(X2,Y1,'k-');hold on;
plot(X1,Y2,'k-.');hold on;
plot(X2,Y3,'k--');hold on;

%% 画出交点
plot(pt(1),pt(2),'.','MarkerSize',24);
axis equal;

实验结果:

在这里插入图片描述

  • 4
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

果壳小旋子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值