opencv Modules SVM使用笔记
写在前面:
最近在完成一个项目时,提取出的样本数据较大,自己也没有良好的处理数据的能力,就想着利用现成的 SVM对数据进行分类。
查了很多资料,对SVM的大致原理有点了解,网络上也有开源的LibSvm,可是在没有完全理解SVM原理前,阅读和修改源码有些难度。而opencv恰好有一个Machine Learning模块,囊括了大多数机器学习的算法,而SVM也在其中。
这篇笔记不包含SVM算法的具体原理,只是对opencv中的SVM使用方法做一点拙劣的笔记。下面开始:
SVM是什么?
SVM:Support Vector Machine,即支持向量机。通过Wiki,有个大体的认识支持向量机
源码阅读
源代码来自Introduction to Support Vector Machines,opencv3.1.0的官方网站。自己做了一些注释。
#include <iostream>
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include <opencv2/highgui.hpp>
#include <opencv2/ml.hpp> //ml模块
#define NTRAINING_SAMPLES 100 // 每类的训练样本
#define FRAC_LINEAR_SEP 0.9f // Fraction of samples which compose the linear separable part
using namespace cv;
using namespace cv::ml; //ml命名空间
using namespace std;
int main()
{
const int WIDTH = 512, HEIGHT = 512;
Mat I = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
//---------- 1. 随机生成训练数据(二维数据,超平面为二维) ---
//训练200个样本,每个样本为2维向量
Mat trainData(2 * NTRAINING_SAMPLES, 2, CV_32FC1);
//200*1,用以标注样本类别
Mat labels(2 * NTRAINING_SAMPLES, 1, CV_32SC1);
RNG rng(100); // 随机数生成类
// 设置数据中线性可分部分
int nLinearSamples = (int)(FRAC_LINEAR_SEP * NTRAINING_SAMPLES);
// 为类别1生成随机数
Mat trainClass = trainData.rowRange(0, nLinearSamples);
// x坐标范围为[0,0.4*WIDTH]
Mat c = trainClass.colRange(0, 1);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * WIDTH));
// y坐标范围为[0,WIDTH)
c = trainClass.colRange(1, 2);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));
// 为类别2生成随机数
trainClass = trainData.rowRange(2 * NTRAINING_SAMPLES - nLinearSamples, 2 * NTRAINING_SAMPLES);
// x坐标范围为(0.4*WIDTH,1*WIDTH)
c = trainClass.colRange(0, 1);
rng.fill(c, RNG::UNIFORM, Scalar(0.6*WIDTH), Scalar(WIDTH));
// y坐标范围为[0,1*HEIGHT]
c = trainClass.colRange(1, 2);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));
//------------------ 设置训练数据中线性不可分部分 ---------------
// 为类别1,2生成数据
trainClass = trainData.rowRange(nLinearSamples, 2 * NTRAINING_SAMPLES - nLinearSamples);
// x坐标范围是 [0.4*WIDTH, 0.6*WIDTH)
c = trainClass.colRange(0, 1);
rng.fill(c, RNG::UNIFORM, Scalar(0.4*WIDTH), Scalar(0.6*WIDTH));
// y坐标范围是 [0, 1*HEIGHT)
c = trainClass.colRange(1, 2);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));
//-------- 分类(为类别1,2设置标志)----------
labels.rowRange(0, NTRAINING_SAMPLES).setTo(1); // 类1
labels.rowRange(NTRAINING_SAMPLES, 2 * NTRAINING_SAMPLES).setTo(2); // 类2
//------------------ 2. 设置支持向量机范围 -------------
//------------------- 3. 训练SVM -------------------------
cout << "Starting training process" << endl;
Ptr<SVM> svm = SVM::create(); //声明SVM对象
svm->setType(SVM::C_SVC); //SVM模型选择
svm->setC(0.2); //惩罚因子设置(原始0.1)
svm->setKernel(SVM::LINEAR); //核函数类型:线性
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, (int)1e7, 1e-6)); //迭代要求
svm->train(trainData, ROW_SAMPLE, labels); //训练
cout << "Finished training process" << endl;
//------------------------ 4. 显示结果区域 ----------------------------------------
Vec3b green(0, 100, 0), blue(100, 0, 0);
for (int i = 0; i < I.rows; ++i)
for (int j = 0; j < I.cols; ++j)
{
//Mat_<float>为Mat1f,一维向量v=(i,j)
Mat = (Mat1f(1, 2) << i, j);
//Mat sampleMat = (Mat_<float>(1, 2) << i, j);
//判别是哪一类
if (response == 1) I.at<Vec3b>(j, i) = green;
else if (response == 2) I.at<Vec3b>(j, i) = blue;
}
//----------------------- 5. 显示训练数据-------------------------------
int thick = -1;
int lineType = 8;
float px, py;
// 类 1
for (int i = 0; i < NTRAINING_SAMPLES; ++i)
{
px = trainData.at<float>(i, 0);
py = trainData.at<float>(i, 1);
circle(I, Point((int)px, (int)py), 3, Scalar(0, 255, 0), thick, lineType);
}
// 类2
for (int i = NTRAINING_SAMPLES; i <2 * NTRAINING_SAMPLES; ++i)
{
px = trainData.at<float>(i, 0);
py = trainData.at<float>(i, 1);
circle(I, Point((int)px, (int)py), 3, Scalar(255, 0, 0), thick, lineType);
}
//------------------------- 6. 画出支持向量 ----------------------------
thick = 2;
lineType = 8;
Mat sv = svm->getUncompressedSupportVectors();
for (int i = 0; i < sv.rows; ++i)
{
const float* v = sv.ptr<float>(i);
circle(I, Point((int)v[0], (int)v[1]), 6, Scalar(0, 0, 0), thick, lineType);
}
imwrite("result.png", I); // save the Image
imshow("SVM for Non-Linear Training Data", I); // show it to the user
waitKey(0);
}
运行环境是VS2015+opencv3.1.0。
思路阐述
以上代码通过core.hpp中的RNG随机数生成类,生成了200个二维向量作为训练数据,数值大小在[0,512)中,以便可以用一张图片表示,存储在一个200*2的Mat结构里。同时生成一个Mat的label,将x坐标在一定范围内点标识为类1,其余为类2。生成的训练数据二维可分,故超平面为二维平面内的直线。
下面通过创建一个SVM类,调用几个函数,设定参数,并训练数据。
Ptr<SVM> svm = SVM::create(); //声明SVM对象
svm->setType(SVM::C_SVC); //SVM模型选择
svm->setC(0.1); //惩罚因子设置
svm->setKernel(SVM::LINEAR); //核函数类型:线性
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER,
(int)1e7, 1e-6)); //迭代要求
svm->train(trainData, ROW_SAMPLE, labels); //训练
cout << "Finished training process" << endl;
几个地方需要注意一下:
- 1.SVM类型选择:setType(SVM::C_SVC)
SVM模型类型枚举:
enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };
如果只是用ml::SVM,完成简单的数据分类,则可以这个参数设置为C_SVC.
标志 | 作用 |
---|---|
C_SVC | C表示惩罚因子,C越大表示对错误分类的惩罚越大 |
NU_SVC | 和C_SVC类似(参数C的范围不同,C_SVC采用的是0到正无穷,NU_SVC是[0,1]。) |
ONE_CLASS | 需要类标号,用于支持向量的密度估计和聚类. |
EPSILON_SVR | 敏感损失函数,对样本点来说,存在着一个不为目标函数提供任何损失值的区域,即-带。 |
NU_SVR | 由于EPSILON_SVR需要事先确定参数,然而在某些情况下选择合适的参数却不是一件容易的事情。而NU_SVR能够自动计算参数。 |
2.惩罚因子设置: svm->setC(0.1)
①大的C(惩罚因子)使得解决方案更小的误分类率和更小的间隔。考虑到在这种情况下,分类错误的代价是昂贵的。因为优化的目的是减小不同,更小的误分类率是被接受的
②小的C(惩罚因子)使得解决方案有更大的分类间隔和更大的误分类率。这种情况下,最小化没有考虑太多总和,因此它更注重于发现有更大间隔的高维空间。3.选择核函数类型
简要说一下,核函数在SVM的作用。
SVM是一个二分类器,目的是找到一个平面,将已知的数据分为两类。以上代码已经实现了对二维空间中数据的分类,二维空间中的平面即为分类的直线(又叫做支持向量)。设想数据复杂一点,虽然维度还是三维,可是已经不能找到简单的平面来对数据进行分类。见图从人眼来看,很明显可以用一个圆来对两类数据进行分类,而SVM所定义的最佳分类平面在二维空间中是直线。所以核函数出现了,它能将二维空间变化到高维空间。现在将两个圆坐标变换到三维空间,经过简单的旋转变换,可以用平面对数据进行分类,可见下图(来自pluskid:下面的gif 动画,先用 Matlab 画出一张张图片,再用 Imagemagick 拼贴成)
好了,回到opencv中的SVM核函数:
enum {LINEAR, POLY, RBF,SIGMOID, PRECOMPUTED };
LINEAR:线性核函数(linear kernel)
POLY:多项式核函数(ploynomial kernel)
RBF:径向机核函数(radical basis function)
SIGMOID:神经元的非线性作用函数核函数(Sigmoid tanh)根据数据的特性选择不同核函数,进行SVM分类。
4.训练
svm->train(trainData, ROW_SAMPLE, labels);
trainData:训练数据
labels:数据类型几率
关于ROW_SAMPLE,见ml.hpp中的定义(添加注释)
enum SampleTypes
{
//每一行为一个样本数据
ROW_SAMPLE = 0, //!< each training sample is a row of samples
//每一列为一个样本数据
COL_SAMPLE = 1 //!< each training sample occupies a column of samples
};
知道以上内容后,我们就可以利用opencv中的SVM进行简单的数据分类。
————–2017.3.8