Matlab 实现线性svm两类和多类分类器

线性分类和SVM原理

网上有很多写的好的博客讲解线性分类和SVM,本人讲解能力差,就给个链接。
http://blog.csdn.net/mm_bit/article/details/46988925

SVM实现代码

训练svm分类器实际上是解二次规划问题,matlab里用到的是quadprog函数,其使用用法可参见matlab官方文档:
http://cn.mathworks.com/help/optim/ug/quadprog.html
或者不想读英文文档的人可以看别人写的博客:
http://blog.csdn.net/jbb0523/article/details/50598641
会使用quadprog函数基本上就会写svm分类器了,这里贴上源码:
svmTrain.m

function [ svm ] = svmTrain( trainData,trainLabel,kertype,C )
options=optimset;
options.LargerScale='off';
options.Display='off';

n=length(trainLabel);
H=(trainLabel'*trainLabel).*kernel(trainData,trainData,kertype);
f=-ones(n,1);
A=[];
b=[];
Aeq=trainLabel;
beq=0;
lb=zeros(n,1);
ub=C*ones(n,1);
a0=zeros(n,1);
[a,fval,eXitflag,output,lambda]=quadprog(H,f,A,b,Aeq,beq,lb,ub,a0,options);
epsilon=1e-8;
sv_label=find(abs(a)>epsilon);
svm.a=a(sv_label);
svm.Xsv=trainData(:,sv_label);
svm.Ysv=trainLabel(sv_label);
svm.svnum=length(sv_label);
end

kernel.m(更新)

function K = kernel( X,Y,type )
switch type
    case 'linear'
        K=X'*Y;
    case 'rbf'
        delta=5;
        delta=delta*delta;
        XX=sum(X'.*X',2);
        YY=sum(Y'.*Y',2);
        XY=X'.*Y;
        K=abs(repmat(XX,[1 size(YY,1)])+repmat(YY',[size(XX,1) 1])-2*XY);
        K=exp(-K./delta);
end
end

svmTest.m:

function result = svmTest(svm, Xt, Yt, kertype)  
temp = (svm.a'.*svm.Ysv)*kernel(svm.Xsv,svm.Xsv,kertype);  
%total_b = svm.Ysv-temp;  
b = mean(svm.Ysv-temp);  %b取均值  
w = (svm.a'.*svm.Ysv)*kernel(svm.Xsv,Xt,kertype);  
result.score = w + b;  
Y = sign(w+b);  %f(x)  
result.Y = Y;  
result.accuracy = size(find(Y==Yt))/size(Yt);  
end  

test.m

%------------主函数----------------  
clc;
clear;
C = 10;  %成本约束参数  
kertype = 'linear';  %线性核  

%①------数据准备  
n = 30;  
%randn('state',6);   %指定状态,一般可以不用  
x1 = randn(2,n);    %2行N列矩阵,元素服从正态分布  
y1 = ones(1,n);       %1*N个1  
x2 = 4+randn(2,n);   %2*N矩阵,元素服从正态分布且均值为5,测试高斯核可x2 = 3+randn(2,n);   
y2 = -ones(1,n);      %1*N个-1  

figure;  %创建一个用来显示图形输出的一个窗口对象  
plot(x1(1,:),x1(2,:),'bs',x2(1,:),x2(2,:),'k+');  %画图,两堆点  
axis([-3 8 -3 8]);  %设置坐标轴范围  
hold on;    %在同一个figure中画几幅图时,用此句  

%②-------------训练样本  
X = [x1,x2];        %训练样本2*n矩阵,n为样本个数,d为特征向量个数  
Y = [y1,y2];        %训练目标1*n矩阵,n为样本个数,值为+1或-1  
svm = svmTrain(X,Y,kertype,C);  %训练样本  
plot(svm.Xsv(1,:),svm.Xsv(2,:),'ro');   %把支持向量标出来  

%③-------------测试  
[x1,x2] = meshgrid(-2:0.05:7,-2:0.05:7);  %x1和x2都是181*181的矩阵  
[rows,cols] = size(x1);    
nt = rows*cols;                    
Xt = [reshape(x1,1,nt);reshape(x2,1,nt)];  
%前半句reshape(x1,1,nt)是将x1转成1*(181*181)的矩阵,所以xt是2*(181*181)的矩阵  
%reshape函数重新调整矩阵的行、列、维数  
Yt = ones(1,nt);  

result = svmTest(svm, Xt, Yt, kertype);  

%④--------------画曲线的等高线图  
Yd = reshape(result.Y,rows,cols);  
contour(x1,x2,Yd,[0,0],'ShowText','on');%画等高线  
title('svm分类结果图');     
x1=xlabel('X轴');    
x2=ylabel('Y轴');   

分类结果:
这里写图片描述

多类分类器

多类线性问题的一种解法是将多类分为多个两类分类器。
如训练集有1,2,3,4个类,让1与2作两类分类器训练得到12分类器,1与3得到13分类器,等等两两训练,得到(4-1)*4/2 6个分类器,再将测试数据用这6个分类器一一测试,如果对12,13,14三个都是正的,则该类属于1类。
代码:
multiLiner:

clc;
clear;
C=10;
kertype='linear';
%生成测试数据
n=30;
x1=randn(2,n);
x2=4+randn(2,n);

x3=randn(2,n);
x3=[x3(1,:)+4;x3(2,:)-4];

x4=randn(2,n);
x4=[x4(1,:)+8;x4(2,:)];
%可视化生成数据
plot(x1(1,:),x1(2,:),'bs',x2(1,:),x2(2,:),'k+');
hold on;
plot(x3(1,:),x3(2,:),'r*',x4(1,:),x4(2,:),'y.');
axis([-3 11 -7 7]);
hold on;
%两两合成一个训练组训练模型
trainData12=[x1,x2];
trainData13=[x1,x3];
trainData14=[x1,x4];
trainData23=[x2,x3];
trainData24=[x2,x4];
trainData34=[x3,x4];
trainLabel=[ones(1,n),-ones(1,n)];

svm_12=svmTrain(trainData12,trainLabel,kertype,C);
svm_13=svmTrain(trainData13,trainLabel,kertype,C);
svm_14=svmTrain(trainData14,trainLabel,kertype,C);
svm_23=svmTrain(trainData23,trainLabel,kertype,C);
svm_24=svmTrain(trainData24,trainLabel,kertype,C);
svm_34=svmTrain(trainData34,trainLabel,kertype,C);

%生成测试数据
[x1,x2] = meshgrid(-2:0.05:10,-6:0.05:6);  %x1和x2都是181*181的矩阵  
[rows,cols] = size(x1);    
nt = rows*cols;                    
Xt = [reshape(x1,1,nt);reshape(x2,1,nt)];  
%前半句reshape(x1,1,nt)是将x1转成1*(181*181)的矩阵,所以xt是2*(181*181)的矩阵  
%reshape函数重新调整矩阵的行、列、维数  
Yt = ones(1,nt);  
result12=svmTest(svm_12,Xt,Yt,kertype);
Yd = reshape(result12.Y,rows,cols); 
contour(x1,x2,Yd,[0,0],'ShowText','on');%画等高线  
hold on;
result13=svmTest(svm_13,Xt,Yt,kertype);
Yd = reshape(result13.Y,rows,cols); 
contour(x1,x2,Yd,[0,0],'ShowText','on');%画等高线  
hold on;
result14=svmTest(svm_14,Xt,Yt,kertype);
Yd = reshape(result14.Y,rows,cols); 
contour(x1,x2,Yd,[0,0],'ShowText','on');%画等高线  
hold on;
result23=svmTest(svm_23,Xt,Yt,kertype);
Yd = reshape(result23.Y,rows,cols); 
contour(x1,x2,Yd,[0,0],'ShowText','on');%画等高线  
hold on;
result24=svmTest(svm_24,Xt,Yt,kertype);
Yd = reshape(result24.Y,rows,cols); 
contour(x1,x2,Yd,[0,0],'ShowText','on');%画等高线  
hold on;
result34=svmTest(svm_34,Xt,Yt,kertype);
Yd = reshape(result34.Y,rows,cols); 
contour(x1,x2,Yd,[0,0],'ShowText','on');%画等高线

%测试一个样本点属于哪一类
Xt=[10;2];
Yt=1;
result12=svmTest(svm_12,Xt,Yt,kertype);
result13=svmTest(svm_13,Xt,Yt,kertype);
result14=svmTest(svm_14,Xt,Yt,kertype);
result23=svmTest(svm_23,Xt,Yt,kertype);
result24=svmTest(svm_24,Xt,Yt,kertype);
result34=svmTest(svm_34,Xt,Yt,kertype);
if result12.Y==1&&result13.Y==1&&result14.Y==1
    testLabel=1;
elseif result12.Y==-1&&result23.Y==1&&result24.Y==1
    testLabel=2;
elseif result13.Y==-1&&result23.Y==-1&&result34.Y==1
    testLabel=3;
elseif result14.Y==-1&&result24.Y==-1&&result34.Y==-1
    testLabel=4;
else
    testLabel=-1;
    disp('测试点不属于这4类中');
end

分类结果:
这里写图片描述

  • 23
    点赞
  • 334
    收藏
    觉得还不错? 一键收藏
  • 16
    评论
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值