MATLAB实现SVM多分类(one-vs-rest),利用自带函数fitcsvm

SVM多分类

SVM也叫支持向量机,其是一个二类分类器,但是对于多分类,SVM也可以实现。主要方法就是训练多个二类分类器。常见的有以下两种方式:

一对一(one-vs-one)

给定m个类,对m个类中的每两个类都训练一个分类器,总共的二类分类器个数为 m(m-1)/2 .比如有三个类,1,2,3,那么需要有三个分类器,分别是针对:1和2类,1和3类,2和3类。对于一个需要分类的数据x,它需要经过所有分类器的预测,最后使用投票的方式来决定x最终的类属性。

一对多(one-vs-rest)

给定m个类,需要训练m个二类分类器。其中的分类器 i 是将 i 类数据设置为类1(正类),其它所有m-1个i类以外的类共同设置为类2(负类),这样,针对每一个类都需要训练一个二类分类器,最后,我们一共有 m 个分类器。对于一个需要分类的数据 x,通常选择置信度最大的类别标记为分类结果。

fitcsvm简单介绍

在新版本中svmtrain和svmclassify函数提示已经被移除,所以我们应该跟上潮流学习使用fitcsvm。

// An highlighted block
SVMModel =  fitcsvm(X,Y,'ClassNames',{'negClass','posClass'},'Standardize',true,...
        'KernelFunction','rbf','BoxConstraint',1);

简单说一下参数:
X是训练样本,nxm的矩阵,n是样本数,m是特征维数;
Y是样本标签,nx1的矩阵,n是样本数;
‘ClassNames’,{‘negClass’,‘posClass’} 为键值对参数,指定正负类别,负类名在前,正类名在后,与样本标签Y中的元素对应;
‘Standardize’,true 为键值对参数,指示软件是否应在训练分类器之前使预测期标准化!
‘KernelFunction’,‘rbf’ 为键值对参数,有3种 ‘linear’(默认), ‘gaussian’ (or ‘rbf’), ‘polynomial’
‘BoxConstraint’,1 为键值对参数,直观上可以理解为一个惩罚因子(或者说正则参数),这个参数和svmtrain里的-c是一个道理。其实际上涉及到软间隔SVM的间隔(Margin)大小。
基本思想如下:当原始数据未能呈现出较好的可分性时,算法允许其在训练集上呈现出一些误分类,matlab默认的BoxConstraint为1。框约束的数值越大,意味着惩罚力度越小,最后得到的分类超平面的间隔越小,支持向量数越多,模型越复杂。这也就是很多机器学习理论书中一开始推导的硬间隔支持向量机(Hard-Margin SVM)。因为该参数默认为1,所以使用默认参数训练时,我们采用的是软间隔SVM。
更详细的大家可以参考官方说明文档 [https://ww2.mathworks.cn/help/stats/fitcsvm.html].

代码

说一下思路:
1.我自己造的数据不用太关心,训练数据是60x2,60是样本数,2是特征数;测试数据是20x2的。
2.目标是分5类,一对多的方式,就要分别训练5个SVM模型;每个模型都是一个二分类,所以需要正、负样本的划分。我是这么做的正样本全部来自该类别,负样本从其它4个类别中随机选择,但数目与正样本相同。有了每一类的正、负样本,这就得到了训练样本X;再设定标签,我设的是+1,-1,这就得到了样本标签Y;其它参数均默认不设,这样就可以为每一类样本训练SVM模型了。
3.测试样本并不需要对每一类划分正、负样本,只要知道测试数据和样本标签即可。
4.每个测试样本在5个SVM模型中均得到一个得分score,利用最大得分判定该样本最终属于哪一类。
5.这个混淆矩阵函数confusionmat是真的好用,只需要知道真实标签和预测标签就能算出查准率(precision)、查全率(recall)和综合评价指标(F-measure)。
如图:

哈哈哈
类别1的查准率 = a / ( a + d + g ) =a/(a+d+g) =a/(a+d+g)
类别1的查全率 = a / ( a + b + c ) =a/(a+b+c) =a/(a+b+c)
类别2的查准率 = e / ( b + e + h ) =e/(b+e+h) =e/(b+e+h)
类别2的查全率 = e / ( d + e + f ) =e/(d+e+f) =e/(d+e+f)
···

// An highlighted block
clc;
clear;
close all;
tic
fprintf('-----已开始请等待-----\n\n');
%% 造数据不用关心,直接跳过
% 造数据 20*2
data = [0.4,0.3;-0.5,0.1;-0.2,-0.3;0.5,-0.3;
        2.1,1.9;1.8,2.2;1.7,2.5;2.3,1.6;
        -2.2,1.6;-1.9,2.1;-1.7,2.6;-2.3,2.5;
        -3.1,-1.9;-2.8,-2.1;-1.9,-2.5;-2.3,-3.2;
        3.9,-3.5;2.8,-2.2;1.7,-3.1;2.5,-3.4];
data1 = data + 2.5*rand(20,2);
data2 = data + 2.5*rand(20,2);
data3 = data + 2.5*rand(20,2); data1(17:20,:);
% 训练数据
train_data = [data1(1:4,:);data2(1:4,:);data3(1:4,:);
              data1(5:8,:);data2(5:8,:);data3(5:8,:);
              data1(9:12,:);data2(9:12,:);data3(9:12,:);
              data1(13:16,:);data2(13:16,:);data3(13:16,:);
              data1(17:20,:);data2(17:20,:);data3(17:20,:)];
                  
% 画图显示
figure;
% gscatter函数可以按分类或者分组画离散点
% group为分组向量,对应每一个坐标的类别
group_train = [1;1;1;1;1;1;1;1;1;1;1;1;
         2;2;2;2;2;2;2;2;2;2;2;2;
         3;3;3;3;3;3;3;3;3;3;3;3;
         4;4;4;4;4;4;4;4;4;4;4;4;
         5;5;5;5;5;5;5;5;5;5;5;5];
gscatter(train_data(:,1),train_data(:,2),group_train);

title('训练数据样本分布');
xlabel('样本特征1');
ylabel('样本特征2');
legend('Location','Northwest');
grid on;

%%
% 测试数据
test_data = data + 3.0*rand(20,2);
test_features = test_data;
% 测试数据的真实标签
test_labels = [1;1;1;1;2;2;2;2;3;3;3;3;4;4;4;4;5;5;5;5];

%%
% 训练数据分为5% 类别i的 正样本 选择类别i的全部,负样本 从其余类别中随机选择(个数与正样本相同)
% 类别1
class1_p = train_data(1:12,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(1:12,:) = [];
class1_n = train_data_c(index1,:);

train_features1 = [class1_p;class1_n];
% 正类表示为1,负类表示为-1
train_labels1 = [ones(12,1);-1*ones(12,1)];

% 类别2
class2_p = train_data(13:24,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(13:24,:) = [];
class2_n = train_data_c(index1,:);

train_features2 = [class2_p;class2_n];
% 正类表示为1,负类表示为-1
train_labels2 = [ones(12,1);-1*ones(12,1)];

% 类别3
class3_p = train_data(25:36,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(25:36,:) = [];
class3_n = train_data_c(index1,:);

train_features3 = [class3_p;class3_n];
% 正类表示为1,负类表示为-1
train_labels3 = [ones(12,1);-1*ones(12,1)];

% 类别4
class4_p = train_data(37:48,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(37:48,:) = [];
class4_n = train_data_c(index1,:);

train_features4 = [class4_p;class4_n];
% 正类表示为1,负类表示为-1
train_labels4 = [ones(12,1);-1*ones(12,1)];

% 类别5
class5_p = train_data(49:60,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(49:60,:) = [];
class5_n = train_data_c(index1,:);

train_features5 = [class5_p;class5_n];
% 正类表示为1,负类表示为-1
train_labels5 = [ones(12,1);-1*ones(12,1)];

%%
% 分别训练5个类别的SVM模型
model1 = fitcsvm(train_features1,train_labels1,'ClassNames',{'-1','1'});
model2 = fitcsvm(train_features2,train_labels2,'ClassNames',{'-1','1'});
model3 = fitcsvm(train_features3,train_labels3,'ClassNames',{'-1','1'});
model4 = fitcsvm(train_features4,train_labels4,'ClassNames',{'-1','1'});
model5 = fitcsvm(train_features5,train_labels5,'ClassNames',{'-1','1'});
fprintf('-----模型训练完毕-----\n\n');
%%
% label是n*1的矩阵,每一行是对应测试样本的预测标签;
% score是n*2的矩阵,第一列为预测为“负”的得分,第二列为预测为“正”的得分。
% 用训练好的5SVM模型分别对测试样本进行预测分类,得到5个预测标签
[label1,score1] = predict(model1,test_features);
[label2,score2] = predict(model2,test_features);
[label3,score3] = predict(model3,test_features);
[label4,score4] = predict(model4,test_features);
[label5,score5] = predict(model5,test_features);
% 求出测试样本在5个模型中预测为“正”得分的最大值,作为该测试样本的最终预测标签
score = [score1(:,2),score2(:,2),score3(:,2),score4(:,2),score5(:,2)];
% 最终预测标签为k*1矩阵,k为预测样本的个数
final_labels = zeros(20,1);
for i = 1:size(final_labels,1)
    % 返回每一行的最大值和其位置
    [m,p] = max(score(i,:));
    % 位置即为标签
    final_labels(i,:) = p;
end
fprintf('-----样本预测完毕-----\n\n');
% 分类评价指标

group = test_labels; % 真实标签
grouphat = final_labels; % 预测标签
[C,order] = confusionmat(group,grouphat,'Order',[1;2;3;4;5]); % 'Order'指定类别的顺序
c1_p = C(1,1) / sum(C(:,1));
c1_r = C(1,1) / sum(C(1,:));
c1_F = 2*c1_p*c1_r / (c1_p + c1_r);
fprintf('c1类的查准率为%f,查全率为%f,F测度为%f\n\n',c1_p,c1_r,c1_F);

c2_p = C(2,2) / sum(C(:,2));
c2_r = C(2,2) / sum(C(2,:));
c2_F = 2*c2_p*c2_r / (c2_p + c2_r);
fprintf('c2类的查准率为%f,查全率为%f,F测度为%f\n\n',c2_p,c2_r,c2_F);

c3_p = C(3,3) / sum(C(:,3));
c3_r = C(3,3) / sum(C(3,:));
c3_F = 2*c3_p*c3_r / (c3_p + c3_r);
fprintf('c3类的查准率为%f,查全率为%f,F测度为%f\n\n',c3_p,c3_r,c3_F);

c4_p = C(4,4) / sum(C(:,4));
c4_r = C(4,4) / sum(C(4,:));
c4_F = 2*c4_p*c4_r / (c4_p + c4_r);
fprintf('c4类的查准率为%f,查全率为%f,F测度为%f\n\n',c4_p,c4_r,c4_F);

c5_p = C(5,5) / sum(C(:,5));
c5_r = C(5,5) / sum(C(5,:));
c5_F = 2*c5_p*c5_r / (c5_p + c5_r);
fprintf('c5类的查准率为%f,查全率为%f,F测度为%f\n\n',c5_p,c5_r,c5_F);  
            
            
figure;
subplot(121);
% gscatter函数可以按分类或者分组画离散点
% group为分组向量,对应每一个坐标的类别
group_test = test_labels;
gscatter(test_data(:,1),test_data(:,2),group_test);

title('测试数据样本真实分布');
xlabel('样本特征1');
ylabel('样本特征2');
legend('Location','Northwest');
grid on;

subplot(122);
% gscatter函数可以按分类或者分组画离散点
% group为分组向量,对应每一个坐标的类别
group_test = final_labels;
gscatter(test_data(:,1),test_data(:,2),group_test);

title('测试数据样本预测分布');
xlabel('样本特征1');
ylabel('样本特征2');
legend('Location','Northwest');
grid on;

实验结果图

在这里插入图片描述
在这里插入图片描述

-----已开始请等待-----

-----模型训练完毕-----

-----样本预测完毕-----

c1类的查准率为0.375000,查全率为0.750000,F测度为0.500000

c2类的查准率为0.800000,查全率为1.000000,F测度为0.888889

c3类的查准率为1.000000,查全率为0.750000,F测度为0.857143

c4类的查准率为1.000000,查全率为0.250000,F测度为0.400000

c5类的查准率为1.000000,查全率为0.750000,F测度为0.857143

第一次写博客,还请大家多多包涵,欢迎指教!

参考资料:
[https://www.cnblogs.com/litthorse/p/9303711.html].
[https://blog.csdn.net/qq_39328617/article/details/95207473].
[https://baijiahao.baidu.com/s?id=1619821729031070174&wfr=spider&for=pc].

  • 136
    点赞
  • 626
    收藏
    觉得还不错? 一键收藏
  • 33
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 33
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值