最近做个小东西,要用到SVM,搜索网上,发现大伙都是各种介绍理论,让我等小码农晕头转向,是故自己学习总结一下,并将代码实例展示出来,方便大家共同探讨,该代码是用opencv编写的,很容易学习滴。
1、SVM小介绍
SVM是一种用超平面定义的分类器,是一种监督的分类算法。即使用带标签的训练数据,SVM得到优化的超平面,使得两类之间的距离最大,这样有什么好处呢?显而遇见,这样可以降低噪声干扰,因为超平面到数据点的距离是最大距离的一半,只要噪声扰动不要越过超平面即可。
推导过程我就不详写了,因为这个页面写不了公式。我手写了个大概的,拍照上传,见下图:
也许你会问,这个只是对于线性可分问题的,假如对于线性不可分的呢?其实,这个担心是多余的,SVM可通过核函数升维,将不可分问题变为可分。
也许你还问,怎么使用朗格朗日乘子得到权重向量?我觉得,这个方法很常见,尤其在优化算法里面,下篇博文我会讲一下这个算法的原理及代码实现。
2、实例代码
说了那么多,该搞点实际的东西了,要不然我也成理论家了。看下面这个例子:
#include "StdAfx.h"
#include
#include
#include
using namespace cv;
int main()
{
// Data for visual representation
int width = 512, height = 512;
Mat image = Mat::zeros(height, width, CV_8UC3);
// Set up training data
float labels[4] = {1.0, -1.0, -1.0, -1.0};
Mat labelsMat(4, 1, CV_32FC1, labels);
float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} };
Mat trainingDataMat(4, 2, CV_32FC1, trainingData);
// Set up SVM's parameters
CvSVMParams params;
params.svm_type = CvSVM::C_SVC;
params.kernel_type = CvSVM::LINEAR;
params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);
// Train the SVM
CvSVM SVM;
SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);
Vec3b green(0,255,0), blue (255,0,0);
// Show the decision regions given by the SVM
for (int i = 0; i < image.rows; ++i)
for (int j = 0; j < image.cols; ++j)
{
Mat sampleMat = (Mat_
(1,2) << j,i);
float response = SVM.predict(sampleMat);
if (response == 1)
image.at
(i,j) = green; else if (response == -1) image.at
(i,j) = blue; } // Show the training data int thickness = -1; int lineType = 8; circle( image, Point(501, 10), 5, Scalar( 0, 0, 0), thickness, lineType); circle( image, Point(255, 10), 5, Scalar(255, 255, 255), thickness, lineType); circle( image, Point(501, 255), 5, Scalar(255, 255, 255), thickness, lineType); circle( image, Point( 10, 501), 5, Scalar(255, 255, 255), thickness, lineType); // Show support vectors thickness = 2; lineType = 8; int c = SVM.get_support_vector_count(); for (int i = 0; i < c; ++i) { const float* v = SVM.get_support_vector(i); circle( image, Point( (int) v[0], (int) v[1]), 6, Scalar(128, 128, 128), thickness, lineType); } imshow("SVM Simple Example", image); // show it to the user waitKey(0); }
运行结果:
结果分析,上图是通过SVM分类器实现的类别划分,对于二维数据,超平面可以理解为右上角的分界线,距离超平面最近的点是支持向量,如图画灰环的三个点。什么意思呢?就是说这个超平面主要有着三个点确定。至于代码吧,有注解,大家应该都懂,就不说了。
好了,你应该对SVM有大概的认识了吧。