一、SVM介绍
传统学习方法采用的经验风险最小化方法(ERM)虽然将误差最小化,但不能最小化学习过程的泛化误差。ERM方法不成功的例子就是神经网络中的过学习问题。为此,由Vapnik领导的贝尔实验室研究小组于1963年提出了一种新的非常有潜力的技术,支持向量机(Support vector machine,SVM)是一种基于统计学习理论的模式识别方法,主要用用于模式识别领域。
SVM的基本思想是在样本空间或特征空间构造出最优超平面,使得超平面与不同样本集之间的距离最大,从而达到最大的泛化能力。
关于SVM原理的解析网上有很多详尽的资料了,在这里列出一些我个人觉得解释的不错的。
在看别人的总结中,发现了一些不错的点,在这里摘录一下。(详细的还得去翻上面的两个连接)
1、SVM 算法特性
(1)训练好的模型的算法复杂度是由支持向量的个数决定的,而不是由数据的维度决定的。所以 SVM 不太容易产生 overfitting。
(2)SVM 训练出来的模型完全依赖于支持向量,即使训练集里面所有非支持向量的点都被去除,重复训练过程,结果仍然会得到完全一样的模型。
(3)一个 SVM 如果训练得出的支持向量个数比较少,那么SVM 训练出的模型比较容易被泛化。
2、SVM基本思想的简单描述(不包含推导,只是为了方便我记忆)
(1)分类线方程:
(2)分类间隔:,使分类间隔最大化等价于使
最小化
(3)超平面一边的点都有,另一边的点则有
(4)假设现在有两类数据,标签y为+1和-1,调整w,使超平面定义如下:
综合以上两式,有,这个式子是约束条件
(5)因此这个最小化优化问题可以用拉格朗日方法求解,令
其中为每个样本的拉格朗日乘子。
(6)现在需要最大化,所以问题变成
。根据拉格朗日的对偶性,原始的极小极大问题变成了现在的极大极小问题
(7)求
L分别对b和w求导数并令其等于0,得到:
把上面这两个式子代回,得到
(8)求 对
的极大值
问题变为:
所以只要知道 , 也就知道
了。(关键点就在求解
)
的求解:使用SMO(Sequential Minimal Optimization)即序列最小最优化来求解。(过程比较复杂,不做详解)
得到 后,相当于知道哪些样本点是支持向量了,因为支持向量的
>0,非支持向量的
=0。
(9)然后把 代入
,求解得到
。
再把 代入
,求解得到b。
(10)至此, 和b都得到了,所以代入超平面公式可得超平面
而分类决策函数为:,根据符号来判断测试样本的类别。
二、libsvm在matlab中的使用
(1)首先去官网下载libsvm工具包,将其解压到MATLAB->toolbox的安装路径
(2)在MATLAB软件中设置路径->添加文件夹,将解压后的文件夹添加到工具箱
(3)然后将工作区设置到这个文件夹下的matlab文件夹下,在MATLAB的命令行区执行“make”命令来编译文件(一定要编译),直到文件夹下出现了4个后缀为mexw64的文件
好了,现在可以使用libsvm的函数了。
(1)libsvm的两个主要函数为:svmtrain和svmpredict
关于这两个函数的参数和返回变量的介绍和使用可以看这篇文章:libsvmpredict和svmtrain的参数和返回值
(2)libsvm包含了几种类型的SVM算法,关于他们的介绍可以看这篇文章:libsvm中svm类型简介
下面是一段《模式识别与人工智能(基于matlab)》中的代码。
clear;
clc;
load SVM % 数据都存在SVM.mat这个文件中,要load一下
% 训练数据和标签
% 数据有3个属性,4个类别
% 训练数据有30个
train_train = [train(1:4,:);train(5:11,:);train(12:19,:);train(20:30,:)]; % 手动划4分类
train_target = [target(1:4);target(5:11);target(12:19);target(20:30)];
% 测试数据和标签
% 测试数据有30个
test_simulation = [simulation(1:6,:);simulation(7:11,:);simulation(12:24,:);simulation(25:30,:)];
test_labels = [labels(1:6);labels(7:11);labels(12:24);labels(25:30)];
model = svmtrain(train_target, train_train, '-c 2 -g 0.2 -t 1'); % 核函数为多项式核函数
[predict_label, accuracy, dec_values] = svmpredict(test_labels, test_simulation, model);
% predict_label:预测得到的测试样本的标签
% accuracy:预测准确率
% dec_values:样本分别属于每一个类别的概率
predict_label
hold off
f=predict_label';
index1=find(f==1);
index2=find(f==2);
index3=find(f==3);
index4=find(f==4);
plot3(simulation(:,1),simulation(:,2),simulation(:,3),'o');
line(simulation(index1,1),simulation(index1,2),simulation(index1,3),'linestyle','none','marker','*','color','g');
line(simulation(index2,1),simulation(index2,2),simulation(index2,3),'linestyle','none','marker','<','color','r');
line(simulation(index3,1),simulation(index3,2),simulation(index3,3),'linestyle','none','marker','+','color','b');
line(simulation(index4,1),simulation(index4,2),simulation(index4,3),'linestyle','none','marker','>','color','y');
box;grid on;hold on;
xlabel('A');
ylabel('B');
zlabel('C');
title('支持向量机分析图');
运行结果图为:
命令行窗口结果为:
Accuracy = 96.6667% (29/30) (classification)
predict_label =
1
1
1
1
1
1
2
2
2
2
2
3
3
3
3
3
3
2
3
3
3
3
3
3
4
4
4
4
4
4