基于 MATLAB fitcsvm 的 OVR SVM 多分类器实现

本代码参考 MATLAB 官方例程,实现 SVM 多分类器(one-versus-rest)。对于未知标签的测试集,初始化中测试样本的标签,以及结果分析中的验证部分可忽略。

OVR SVM

一对多(one-versus-rest)模型的基本策略:训练时依次将某一类别的样本归为正样本集,其余类的样本归为负样本集;这样对于拥有 n n n 种类型的样本集就可以构造出 n n n 个二分类器。

一般取获得最高分数的分类器对应的正样本集,为当前测试样本的类;对于所有分类器输出负分数的情况,此时测试样本可能不属于任一分类器对应的正样本类;对于多个分类器输出正分数的情况,可以考虑进行多一步骤的判决。

clear;
close all;

%--------------------------------- 初始化 --------------------------------%
% X % 训练样本, 第一个维度为样本数, 第二个维度为特征数
% testX % 测试样本
% Y % 训练样本的标签; 字符向量元胞数组表示
% testY % 测试样本的标签

%----------------------------- 训练一对多分类器 ---------------------------%
classes = unique(Y); % 标签排序去重
SVMModels = cell(size(classes,1),1); % 储存分类器的元胞数组

for j = 1:numel(classes)
	% 通过字符串比较为分类器设置标签, 仅 classes{j} 类标签为 true
    indx = strcmp(Y,classes{j});
    % 核函数根据样本与特征维度选择, 可以配置 'OptimizeHyperparameters' 开启优化
    SVMModels{j} = fitcsvm(X,indx,'ClassName',[false true],'Standardize',true,...
        'KernelFunction','rbf','BoxConstraint',1);
end

%------------------------ 用一对多分类器对测试集分类 ------------------------%
testNum = size(testX,1);
Scores = zeros(testNum,numel(classes));
for j = 1:numel(classes)
    [~,score] = predict(SVMModels{j}, testX);
    Scores(:,j) = score(:,2); % score 第二行为正类的分数
end

% 测试集对于每一个类的分类器都有获得一个分数, 取获得最大分数的类的索引作为标签
[~,maxScore] = max(Scores,[],2); 

%------------------------------ 结果分析 ----------------------------------%
% 交叉验证
for j = 1:numel(classes)
    CVSVMModel = crossval(SVMModels{j});
    classLoss = kfoldLoss(CVSVMModel);
    disp(['类 ' classes{j} '  泛化率 = ' num2str(classLoss)]);
end

% 检验
err = 0;
err2 = 0;
for j = 1:testNum
    % 每一类 SVM 计算都为负值的测试样本判为不属于任何一类
    if Scores(j,maxScore(j)) < 0
        err2 = err2+1;
        continue;
    end
    if ~strcmp(testY{j},classes{maxScore(j)})
        err = err+1;
    end
end

disp(['虚警率 = ' num2str(err/testNum) '  准确率 = ' num2str((testNum-err-err2)/testNum)]);
测试数据

附上造的一组测试数据。

%------------------------ 初始化 ----------------------%
trainNum = 2e3; % 训练样本数
testNum = 2e3; % 测试样本数
X = zeros(trainNum,1);
testX = zeros(testNum,1);
Y = cell(trainNum,1);
testY = cell(testNum,1);

%----------------------- 训练样本 -----------------------%
for j = 1:trainNum
    c = mod(randi(100),3);
    switch c
        case 0
            X(j) = rand();
            Y{j} = 'Class1';
        case 1
            X(j) = rand()+0.5;
            Y{j} = 'Class2';
        case 2
            X(j) = rand()+1;
            Y{j} = 'Class3';
    end
end

%---------------------- 测试样本 -----------------------%
for j = 1:trainNum
    c = mod(randi(100),3);
    switch c
        case 0
            testX(j) = rand();
            testY{j} = 'Class1';
        case 1
            testX(j) = rand()+0.5;
            testY{j} = 'Class2';
        case 2
            testX(j) = rand()+1;
            testY{j} = 'Class3';
    end
end
  • 4
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值