原理:略。
步骤:
二分类问题:
(1)将第一类样本作为正样本,第二类样本作为负样本。首先,对样本的向量空间进行增广,即对n维向量x的首部或者尾部增加一个参数1,增广为(n+1)维向量,并对其进行规范化,即正样本不做处理,负样本的(n+1)维向量取负。
(2)定义一个(n+1)维权向量w,并进行初始化,定义学习步长LearnRate。
(3)进行迭代,对于每个样本,如果w与x的转置的乘积大于0,则不做处理,否则更新权向量:
w=w+LearnRate*x
直到对所有样本的w与x的转置的乘积大于0,退出迭代。
(4)最终得到的w即最终的权向量,得到直线w1+w2*x1+w3*x3+....=0(增广的参数1在(n+1)维向量首部时)。
多分类问题:
(1)同样对所有样本进行增广,但不用进行规范化。
(2)定义k个(n+1)维权向量,k为类别数,并进行初始化,定义学习步长LearnRate。
(3)迭代,如果第i类样本j存在wi*xj'<=wt*xj',其中t为非i类,则进行如下操作:
wi=wi+LearnRate*xj
wt=wt-LearnRate*xj
直到所有wi*xj'>wt*xj',退出迭代。
(4)得到k组权向量,wi-wk=0为第i类和k类样本的分界线。
二维多分类问题代码:
clear
clc
n=6;%样本点个数
class=4;%类别数
pattern=[1 2 -1 0 -1 2;1 1 -1 -1 1 -1]; %每一列为一个样本
Class=[1 1 2 2 3 4]; %类别
LearnRate=0.2;
PlotPats(pattern,Class-1); %绘制样本点
input=[ones(1,n);pattern]'; %每一行为一个规范化后的样本
w=zeros(class,3);
for i=1:30 %迭代次数
break_time=0; %当值为0的时候表明本次迭代没有更新,退出迭代
for j=1:n
j_class=Class(j);
answer=zeros(1,class);
for k=1:class %对于每一个样本对于所有类计算w*x'
answer(k)=w(k,:)*(input(j,:))';
end
ret=0;
for l=1:class %如果某j_class类样本的answer小于该样本与其他类的anwser ret记为1
if(l~=j_class)&&(answer(j_class)<=answer(l))
ret=1;
end
end
if ret==1
break_time=break_time+1;
for m=1:class%同时改变所有权值
if m~=j_class
w(m,:)=w(m,:)-LearnRate*input(j,:);
else
w(m,:)=w(m,:)+LearnRate*input(j,:);
end
end
end
end
if break_time==0
break
end
end
newclass=class*(class-1)/2;
new=zeros(newclass,3);
t=0;
for i=1:class%计算每个分界线
for j=i+1:class
t=t+1;
new(t,:)=w(i,:)-w(j,:)
end
end
for q=1:newclass
PlotBoundary([new(q,:)] ,i,1)%绘制分界线
end
drawnow
其中绘制样本点与绘制分界线的函数PlotPats.m与PlotBoundary.m如下:
function PlotPats(P,D)
% PLOTPATS Plots the training patterns defined by Patterns and Desired.
%
% P - NELTS x NPATS matrix of input patterns (column vectors).
% The first two values in each pattern are used
% as the coordinates of the point to be plotted.
%
% D - NUNITS x NPATS matrix of desired binary output patterns.
% The first 2 bits of the output pattern determine the
% class of the point: o, +, *, or x.
[NELTS,NPATS] = size(P);
NUNITS = size(D,1);
if NUNITS<2, D=[D;zeros(1,NPATS)]; end
colordef none
clf reset, whitebg(gcf,[0.82 0.82 0.82])
hold on, box on
% Calculate the bounds for the plot and cause axes to be drawn.
xmin = min(P(1,:)); xmax = max(P(1,:)); xb = (xmax-xmin)*0.2;
ymin = min(P(2,:)); ymax = max(P(2,:)); yb = (ymax-ymin)*0.2;
axis([xmin-xb, xmax+xb, ymin-yb, ymax+yb]);
title('Input Classification');
xlabel('x1'); ylabel('x2');
class = 1 + D(1,:) + 2*D(2,:);
colors = [1 0 1; 1 1 0; 0 1 1; 0 1 0];
symbols = 'o+*x';
for i=1:NPATS
c = class(i);
plot(P(1,i),P(2,i),symbols(c),'Color',colors(c,:),'LineWidth',3);
end
function PlotBoundary(W,iter,done)
colors = jet;
if ~done
lstyle = '--';
color = colors(1+rem(3*iter+9,size(colors,1)),:);
else
lstyle = '-';
color = [1 1 1];
end
d = W(3);
if abs(d) < 0.001, d = 0.001; end
plot([-2 2],(-W(2)*[-2 2]-W(1))/d,'LineStyle',lstyle,'Color',color,'LineWidth',2);
drawnow
运行结果:
三维二分类问题代码:
pattern=[5 2 3 12 30 14;7 3 4 10 12 18;8 5 6 10 36 14];
Desired=[0 0 0 1 1 1];
PlotPats3D(pattern,Desired);
[m n]=size(Desired);
w = [0 0 0 0];
input=[ones(1,n);pattern]';
for i=1:n
if Desired(i)==1
input(i,:)=-input(i,:);
end
end
learnrate=0.8;
for i=1:50
error=0;
for i=1:n
if w*input(i,:)'<=0
error=error+1;
w=w+learnrate*input(i,:);
end
end
if error==0
break
end
end
X=-50:0.5:50;
Y=-50:0.5:50;
[X Y]=meshgrid(X,Y);
Z=-(w(1)+w(2)*X+w(3)*Y)/w(4);
surf(X,Y,Z);
其中绘制三维样本点函数PlotPats3D.m与make3views.m如下:
function PlotPats3D(P,D)
colordef none, clf reset
make3view
maxx=max(P')+10
minx=min(P')-10
axis([ minx(1) maxx(1) minx(2) maxx(2) minx(3) maxx(3)])
view(72,24)
[m,n]=size(P);
for i=1:n
if D(i)==1
plot3(P(1,i),P(2,i),P(3,i),'y+')
elseif D(i)==2
plot3(P(1,i),P(2,i),P(3,i),'mo')
end
end
function make3view
cla
axis([-1 1 -1 1 -1 1]), grid on, box on, hold on
xlabel('x1'), ylabel('x2'), zlabel('x3')
set(gca,'CameraViewAngleMode','manual')
rotate3d on
colormap jet
caxis([-1 1])
运行结果如下:
三维多分类问题代码:
n=6;
class=4;
pattern=[1 2 7 19 20 26;2 1 10 22 19 26;3 2 5 16 17 18];
Class=[1 1 2 3 4 4];
PlotPats3Dmult(pattern,Class);
LearnRate=0.5;
input=[ones(1,n);pattern]';
w=zeros(class,4);
for i=1:class
w(i,:)=[1 -1 -1 -1];
end
for i=1:3000
break_time=0;
for j=1:n
j_class=Class(j);
answer=zeros(1,class);
for k=1:class
answer(k)=w(k,:)*(input(j,:))';
end
ret=0;
for l=1:class
if(l~=j_class)&&(answer(j_class)<=answer(l))
ret=1;
end
end
if ret==1
break_time=break_time+1;
for m=1:class
if m~=j_class
w(m,:)=w(m,:)-LearnRate*input(j,:);
else
w(m,:)=w(m,:)+LearnRate*input(j,:);
end
end
end
end
if break_time==0
break
end
end
newclass=class*(class-1)/2;
new=zeros(newclass,4);
t=0;
for i=1:class
for j=i+1:class
t=t+1;
new(t,:)=w(i,:)-w(j,:)
end
end
for i=1:newclass
X=-50:0.5:50;
Y=-50:0.5:50;
[X Y]=meshgrid(X,Y);
Z=-(new(i,1)+new(i,2)*X+new(i,3)*Y)/new(i,4);
surf(X,Y,Z);
end
其中绘制样本点函数如下:
function PlotPats3Dmult(P,D)
colordef none, clf reset
make3view
maxplot=max(P');
minplot=min(P');
axis([minplot(1)-10 maxplot(1)+10 minplot(2)-10 maxplot(2)+10 minplot(3)-10 maxplot(3)+10])
view(72,24)
[m n]=size(P);
for i=1:n
if D(i)==1
plot3(P(1,i),P(2,i),P(3,i),'y+')
elseif D(i)==2
plot3(P(1,i),P(2,i),P(3,i),'mo')
elseif D(i)==3
plot3(P(1,i),P(2,i),P(3,i),'rx')
else
plot3(P(1,i),P(2,i),P(3,i),'bo')
end
end
最后运行结果如下: