感知器算法实现多类样本的线性分类(Matlab)

原理:略。

步骤:

二分类问题:

(1)将第一类样本作为正样本,第二类样本作为负样本。首先,对样本的向量空间进行增广,即对n维向量x的首部或者尾部增加一个参数1,增广为(n+1)维向量,并对其进行规范化,即正样本不做处理,负样本的(n+1)维向量取负。

(2)定义一个(n+1)维权向量w,并进行初始化,定义学习步长LearnRate。

(3)进行迭代,对于每个样本,如果w与x的转置的乘积大于0,则不做处理,否则更新权向量:

w=w+LearnRate*x

直到对所有样本的w与x的转置的乘积大于0,退出迭代。

(4)最终得到的w即最终的权向量,得到直线w1+w2*x1+w3*x3+....=0(增广的参数1在(n+1)维向量首部时)。

多分类问题:

(1)同样对所有样本进行增广,但不用进行规范化。

(2)定义k个(n+1)维权向量,k为类别数,并进行初始化,定义学习步长LearnRate。

(3)迭代,如果第i类样本j存在wi*xj'<=wt*xj',其中t为非i类,则进行如下操作:

wi=wi+LearnRate*xj

wt=wt-LearnRate*xj

直到所有wi*xj'>wt*xj',退出迭代。

(4)得到k组权向量,wi-wk=0为第i类和k类样本的分界线。

二维多分类问题代码:

clear
clc
n=6;%样本点个数
class=4;%类别数
pattern=[1  2  -1   0  -1  2;1  1 -1  -1  1  -1]; %每一列为一个样本
Class=[1  1  2  2  3  4]; %类别
LearnRate=0.2;
PlotPats(pattern,Class-1); %绘制样本点
input=[ones(1,n);pattern]'; %每一行为一个规范化后的样本
 w=zeros(class,3);
for i=1:30 %迭代次数
    break_time=0; %当值为0的时候表明本次迭代没有更新,退出迭代
     for j=1:n
         j_class=Class(j);
         answer=zeros(1,class);
         for k=1:class %对于每一个样本对于所有类计算w*x'
             answer(k)=w(k,:)*(input(j,:))';
         end
         ret=0;
         for l=1:class %如果某j_class类样本的answer小于该样本与其他类的anwser ret记为1
             if(l~=j_class)&&(answer(j_class)<=answer(l))
                 ret=1;
             end
         end


         if ret==1
             break_time=break_time+1;
             for m=1:class%同时改变所有权值
                 if m~=j_class
                     w(m,:)=w(m,:)-LearnRate*input(j,:);
                 else
                     w(m,:)=w(m,:)+LearnRate*input(j,:);
                 end
             end
         end
     end
     if break_time==0
         break
     end
end


             
newclass=class*(class-1)/2;
new=zeros(newclass,3);
t=0;
for i=1:class%计算每个分界线
    for j=i+1:class
        t=t+1;
        new(t,:)=w(i,:)-w(j,:)
    end
end


    for q=1:newclass 
        PlotBoundary([new(q,:)] ,i,1)%绘制分界线
    end                    
    drawnow
        

其中绘制样本点与绘制分界线的函数PlotPats.m与PlotBoundary.m如下:

function PlotPats(P,D)
% PLOTPATS   Plots the training patterns defined by Patterns and Desired.
%
%            P - NELTS x NPATS matrix of input patterns (column vectors).
%		 The first two values in each pattern are used 
%		 as the coordinates of the point to be plotted.
%
%            D - NUNITS x NPATS matrix of desired binary output patterns.
%		 The first 2 bits of the output pattern determine the
%		 class of the point: o, +, *, or x.

[NELTS,NPATS] = size(P);
NUNITS = size(D,1);

if NUNITS<2, D=[D;zeros(1,NPATS)]; end

colordef none
clf reset, whitebg(gcf,[0.82 0.82 0.82])
hold on, box on

% Calculate the bounds for the plot and cause axes to be drawn.
xmin = min(P(1,:)); xmax = max(P(1,:)); xb = (xmax-xmin)*0.2;
ymin = min(P(2,:)); ymax = max(P(2,:)); yb = (ymax-ymin)*0.2;
axis([xmin-xb, xmax+xb, ymin-yb, ymax+yb]);

title('Input Classification');
xlabel('x1'); ylabel('x2');

class = 1 + D(1,:) + 2*D(2,:);
colors = [1 0 1; 1 1 0; 0 1 1; 0 1 0];
symbols = 'o+*x';

for i=1:NPATS
  c = class(i);
  plot(P(1,i),P(2,i),symbols(c),'Color',colors(c,:),'LineWidth',3);
  end
function PlotBoundary(W,iter,done)

colors = jet;

if ~done
   lstyle = '--';
   color = colors(1+rem(3*iter+9,size(colors,1)),:);
else
   lstyle = '-';
   color = [1 1 1];
end

d = W(3);
if abs(d) < 0.001, d = 0.001; end

plot([-2 2],(-W(2)*[-2 2]-W(1))/d,'LineStyle',lstyle,'Color',color,'LineWidth',2);
drawnow

运行结果:



三维二分类问题代码:

pattern=[5 2 3 12 30 14;7 3 4 10 12 18;8 5 6 10 36 14];
Desired=[0 0 0 1 1 1];
PlotPats3D(pattern,Desired);
[m n]=size(Desired);
w = [0 0 0 0];
 input=[ones(1,n);pattern]';
 for i=1:n
     if Desired(i)==1
         input(i,:)=-input(i,:);
     end
 end
 learnrate=0.8;
 for i=1:50
     error=0;
     for i=1:n
         if w*input(i,:)'<=0
             error=error+1;
             w=w+learnrate*input(i,:);
         end
     end
     if error==0
          break
     end
 end
X=-50:0.5:50;
Y=-50:0.5:50;
[X Y]=meshgrid(X,Y);
Z=-(w(1)+w(2)*X+w(3)*Y)/w(4);
surf(X,Y,Z);

其中绘制三维样本点函数PlotPats3D.m与make3views.m如下:

function PlotPats3D(P,D)  
  
  colordef none, clf reset  
  make3view  
  maxx=max(P')+10
  minx=min(P')-10
  axis([ minx(1) maxx(1) minx(2) maxx(2) minx(3) maxx(3)])  
  view(72,24)  
    
[m,n]=size(P);
for i=1:n
    if D(i)==1
        plot3(P(1,i),P(2,i),P(3,i),'y+')  
    elseif D(i)==2
        plot3(P(1,i),P(2,i),P(3,i),'mo')   
    end
end
function make3view


cla
axis([-1 1 -1 1 -1 1]), grid on, box on, hold on
xlabel('x1'), ylabel('x2'), zlabel('x3')
set(gca,'CameraViewAngleMode','manual')
rotate3d on
colormap jet
caxis([-1 1])

运行结果如下:




三维多分类问题代码:

n=6;
class=4;
pattern=[1 2  7 19 20 26;2 1 10 22 19 26;3 2 5 16 17 18];
Class=[1 1 2 3 4 4];
PlotPats3Dmult(pattern,Class);
LearnRate=0.5;
input=[ones(1,n);pattern]';
w=zeros(class,4);
for i=1:class
    w(i,:)=[1 -1 -1 -1];
end

 for i=1:3000
    break_time=0;
     for j=1:n
         j_class=Class(j);
         answer=zeros(1,class);
         for k=1:class
             answer(k)=w(k,:)*(input(j,:))';
         end
         ret=0;
         for l=1:class
             if(l~=j_class)&&(answer(j_class)<=answer(l))
                 ret=1;
             end
         end

         if ret==1
             break_time=break_time+1;
             for m=1:class
                 if m~=j_class
                     w(m,:)=w(m,:)-LearnRate*input(j,:);
                 else
                     w(m,:)=w(m,:)+LearnRate*input(j,:);
                 end
             end
         end
     end
     if break_time==0
         break
     end
 end
newclass=class*(class-1)/2;
new=zeros(newclass,4);
t=0;
for i=1:class
    for j=i+1:class
        t=t+1;
        new(t,:)=w(i,:)-w(j,:)
    end
end
 for i=1:newclass
    X=-50:0.5:50;
    Y=-50:0.5:50;
    [X Y]=meshgrid(X,Y);
    Z=-(new(i,1)+new(i,2)*X+new(i,3)*Y)/new(i,4);
    surf(X,Y,Z);
 end


其中绘制样本点函数如下:

function PlotPats3Dmult(P,D)

  colordef none, clf reset
  make3view
  maxplot=max(P');
  minplot=min(P');
  axis([minplot(1)-10 maxplot(1)+10 minplot(2)-10 maxplot(2)+10 minplot(3)-10 maxplot(3)+10])
  view(72,24)
  [m n]=size(P);
  for i=1:n
      if D(i)==1
            plot3(P(1,i),P(2,i),P(3,i),'y+')
      elseif D(i)==2
            plot3(P(1,i),P(2,i),P(3,i),'mo')
      elseif D(i)==3
            plot3(P(1,i),P(2,i),P(3,i),'rx')
      else
           plot3(P(1,i),P(2,i),P(3,i),'bo')
      end
  end
  


最后运行结果如下:




没有更多推荐了,返回首页