matlab利用hinge loss实现多分类SVM

1 介绍

本文将介绍hinge loss E(w) 以及其梯度 E(w) 。并利用批量梯度下降方法来优化hinge loss实现SVM多分类。利用hinge loss在手写字数据库上实验,能达到87.040%的正确识别率。


2. hinge loss

  1. 根据二分类的SVM目标函数,我们可以定义多分类的SVM目标函数:
    E(w1,,wk)=kj=112||wj||2+Cni=1L((w1,,wk),(xi,yi)) .

其中 T={(x1,y1),,(xn,yn)} 为训练集。 L((w1,,wk),(x,y))=max(0,maxyywTyx+1wTyx) . 二分类SVM转化为多分类SVM的相关资料和公式推导可以参见其他文献。
2. 接下介绍 E(w) 的梯度计算。
(a) 如果 wTywTy^x+1 , 那么

L((w1,w2,,wk),(x,y))wj,l=0

(b) 如果 wTy<wTy^x+1 j=y , 那么

L((w1,w2,,wk),(x,y))wj,l=xl

(c) 如果 wTy<wTy^x+1 j=y^ , 那么

L((w1,w2,,wk),(x,y))wj,l=xl

(d) 如果 wTy<wTy^x+1 jy and jy^ , 那么

L((w1,w2,,wk),(x,y))wj,l=0

  1. 利用梯度下降法更新 W={w1,,wk} :
    Wt=Wt1rE(Wt1)

3 code

Muliticlass_svm.m

% 作者:何凌霄
% 中科院自动化所
% 2017315
clear all
clc
%% STEP 0: Initialise constants and parameters
inputSize = 28 * 28; % Size of input vector (MNIST images are 28x28)
numClasses = 10;     % Number of classes (MNIST images fall into 10 classes)
lambda = 1e-2; % Weight decay parameter
learning_rate = 0.1;
iteration=400;
%%======================================================================
%% STEP 1: Load data
load('digits.mat')
images = [train1; train2; train3; train4; train5; train6; train7; train8; train9;train0];
images = images';
labels = [ones(500,1);2*ones(500,1);3*ones(500,1);4*ones(500,1);5*ones(500,1);6*ones(500,1);7*ones(500,1);8*ones(500,1);9*ones(500,1);10*ones(500,1)];
index = randperm(500*10);
images = images(:,index);
labels = labels(index);
inputData = images;
%% STEP 2: Train multiclass svm
[cost, grad, svmOptTheta] = multisvmtrain(numClasses, inputSize, lambda, inputData, labels, iteration, learning_rate);
%% STEP 3: Test
images = [test1; test2; test3; test4; test5; test6; test7; test8; test9;test0];
images = images';
labels = [ones(500,1);2*ones(500,1);3*ones(500,1);4*ones(500,1);5*ones(500,1);6*ones(500,1);7*ones(500,1);8*ones(500,1);9*ones(500,1);10*ones(500,1)];

inputData = images;
svmModel.optTheta = reshape(svmOptTheta, numClasses, inputSize);
svmModel.inputSize = inputSize;
svmModel.numClasses = numClasses;

% You will have to implement softmaxPredict in softmaxPredict.m
[pred] = Multi_SVMPredict(svmModel, inputData);
acc = mean(labels(:) == pred(:));
num_in_class = 500*ones(10,1)';
for i=1:10
    name_class{i}=num2str(i);
end
[confusion_matrix]=compute_confusion_matrix(pred,num_in_class,name_class);
figure; visualize(svmOptTheta');
fprintf('Accuracy: %0.3f%%\n', acc * 100);

multisvmtrain.m

% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
function [lcost, grad, theta] = multisvmtrain(numClasses, inputSize, lambda, data, labels, iteration, learning_rate)
theta = 0.005 * randn(numClasses * inputSize, 1);
theta = reshape(theta, numClasses, inputSize);%将输入的参数列向量变成一个矩阵
numCases = size(data, 2);%输入样本的个数
groundTruth = full(sparse(labels, 1:numCases, 1));%这里sparse是生成一个稀疏矩阵,该矩阵中的值都是第三个值1
cost = 0;
thetagrad = zeros(numClasses, inputSize);
for i = 1:iteration
    [Q, X, cost] = multi_hingeloss_cost(theta, data, groundTruth,lambda);
    [thetagrad] = multi_hingeloss_grad(data,theta, Q, groundTruth, lambda, labels);
    theta = theta - learning_rate*thetagrad;
    lcost(i) = cost;
    grad(i) = sum(sum(thetagrad));
    fprintf('%d, %f\n', i, cost);
end
end

multi_hingeloss_cost.m

% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
function [Q, X, cost] = multi_hingeloss_cost(theta, data, groundTruth,lambda)
groundTruth1 = groundTruth;
groundTruth(find(groundTruth==1)) = -inf;  
groundTruth(find(groundTruth==0)) = 1; 
X = theta*data;
Q = X;
Q = Q.*groundTruth;
Q(find(Q==inf)) = -inf;
temp = X.*groundTruth1;
temp(find(temp==0))=[];
t = max(0, 1 - temp + max(Q));
cost = 1/size(data,2)*sum(t)+lambda*sum(theta(:).^2);

multi_hingeloss_grad.m

% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
function [thetagrad] = multi_hingeloss_grad(data, theta, Q, groundTruth, lambda, labels)
X = theta*data;
[~,q] = max(Q);
Xq = full(sparse(q, 1:size(X,2), 1));
if size(Xq,1)<10
    for i = 1:10-size(Xq,1)
        Xq = [Xq;zeros(1, size(Xq,2))];
    end
end
temp = X.*groundTruth;
temp1 = X.*Xq;
temp1(find(temp1==0))=[];
temp(find(temp==0))=[];
W=(temp - temp1)<1;
Y = zeros(size(X));

for i=1:size(X,2)
    Y(labels(i),i) = -W(i);
    Y(q(i),i) = W(i);
end
thetagrad = 1/size(X,2)*Y*data' + lambda * theta;

Multi_SVMPredict.m

% 作者:何凌霄
% 中科院自动化所
% 2017年3月15
function [pred] = Multi_SVMPredict(svmModel, data)
theta = svmModel.optTheta;  % this provides a numClasses x inputSize matrix
pred = zeros(1, size(data, 2));
[nop, pred] = max(theta * data);
end

compute_confusion_matrix.m

[confusion_matrix]=compute_confusion_matrix(predict_label,num_in_class,name_class)%预测标签,每一类的数目,类别数目  
%predict_label为一维行向量  
%num_in_class代表每一类的个数  
%name_class代表类名  
num_class=length(num_in_class);  
num_in_class=[0 num_in_class];  
confusion_matrix=size(num_class,num_class);  

for ci=1:num_class  
    for cj=1:num_class  
        summer=0;%统计对应标签个数  
        c_start=sum(num_in_class(1:ci))+1;  
        c_end=sum(num_in_class(1:ci+1));  
        summer=size(find(predict_label(c_start:c_end)==cj),2);  
        confusion_matrix(ci,cj)=summer/num_in_class(ci+1);  
    end  
end  

draw_cm(confusion_matrix,name_class,num_class);  

end  

function draw_cm.m

function draw_cm(mat,tick,num_class)  

imagesc(1:num_class,1:num_class,mat);            %# in color  
colormap(flipud(gray));  %# for gray; black for large value.  

textStrings = num2str(mat(:),'%0.2f');    
textStrings = strtrim(cellstr(textStrings));   
[x,y] = meshgrid(1:num_class);   
hStrings = text(x(:),y(:),textStrings(:), 'HorizontalAlignment','center');  
midValue = mean(get(gca,'CLim'));   
textColors = repmat(mat(:) > midValue,1,3);   
set(hStrings,{'Color'},num2cell(textColors,2));  %# Change the text colors  

set(gca,'xticklabel',tick,'XAxisLocation','top');  
set(gca, 'XTick', 1:num_class, 'YTick', 1:num_class);  
set(gca,'yticklabel',tick);  
rotateXLabels(gca, 315 );% rotate the x tick  

visualize.m

function r=visualize(X, mm, s1, s2)
%FROM RBMLIB http://code.google.com/p/matrbm/
%Visualize weights X. If the function is called as a void method,
%it does the plotting. But if the function is assigned to a variable 
%outside of this code, the formed image is returned instead.
if ~exist('mm','var')
    mm = [min(X(:)) max(X(:))];
end
if ~exist('s1','var')
    s1 = 0;
end
if ~exist('s2','var')
    s2 = 0;
end

[D,N]= size(X);
s=sqrt(D);
if s==floor(s) || (s1 ~=0 && s2 ~=0)
    if (s1 ==0 || s2 ==0)
        s1 = s; s2 = s;
    end
    %its a square, so data is probably an image
    num=ceil(sqrt(N));
    a=mm(2)*ones(num*s2+num-1,num*s1+num-1);
    x=0;
    y=0;
    for i=1:N
        im = reshape(X(:,i),s1,s2)';
        a(x*s2+1+x : x*s2+s2+x, y*s1+1+y : y*s1+s1+y)=im;
        x=x+1;
        if(x>=num)
            x=0;
            y=y+1;
        end
    end
    d=true;
else
    %there is not much we can do
    a=X;
end

%return the image, or plot the image
if nargout==1
    r=a;
else

    imagesc(a, [mm(1) mm(2)]);
    axis equal
    colormap gray

end

得到的识别率为87.040%,hinge loss可以和任何深度网络结合完成分类任务。
最后得到的混淆矩阵如下:
这里写图片描述

损失函数图像:
这里写图片描述

数据集见资源,如引用此代码,请注明出处。

  • 12
    点赞
  • 46
    收藏
    觉得还不错? 一键收藏
  • 17
    评论
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值