最近遇到一个多分类的问题,在网上查了些有关 SVM 的资料。这篇日志的内容来自:http://wenku.baidu.com/view/81c3e210f18583d0496459f0.html 。我自己写代码实现了一下,感觉很好用,作为一个学习的例子放在自己的博客里,供以后查阅使用。
-
#include "stdafx.h"

#include "cv.h"
#include "highgui.h"
#include "ml.h"

#include <time.h>
#include <ctype.h>
#include <float.h>
#include <math.h>
#include <iostream>
-
-
using
namespace
std;
-
-
int main(int argc, char* argv[])
-
{
-
int size =
400;
// height and widht of image
-
const
int s =
1000;
// number of data
-
int i, j,sv_num;
-
IplImage* img;
-
-
CvSVM svm = CvSVM();
-
CvSVMParams param;
-
CvTermCriteria criteria;
// 停止迭代标准
-
CvRNG rng = cvRNG(time(
NULL));
-
CvPoint pts[s];
// 定义1000个点
-
float data[s*
2];
// 点的坐标
-
int res[s];
// 点的类别
-
-
CvMat data_mat, res_mat;
-
CvScalar rcolor;
-
-
const
float* support;
-
-
// 图像区域的初始化
-
img = cvCreateImage(cvSize(size,size),IPL_DEPTH_8U,
3);
-
cvZero(img);
-
-
// 学习数据的生成
-
for (i=
0; i<s;++i)
-
{
-
pts[i].x = cvRandInt(&rng)%size;
-
pts[i].y = cvRandInt(&rng)%size;
-
-
if (pts[i].y>
50*
cos(pts[i].x*CV_PI/
100)+
200)
-
{
-
cvLine(img,cvPoint(pts[i].x
-2,pts[i].y
-2),cvPoint(pts[i].x+
2,pts[i].y+
2),CV_RGB(
255,
0,
0));
-
cvLine(img,cvPoint(pts[i].x+
2,pts[i].y
-2),cvPoint(pts[i].x
-2,pts[i].y+
2),CV_RGB(
255,
0,
0));
-
res[i]=
1;
-
}
-
else
-
{
-
if (pts[i].x>
200)
-
{
-
cvLine(img,cvPoint(pts[i].x
-2,pts[i].y
-2),cvPoint(pts[i].x+
2,pts[i].y+
2),CV_RGB(
0,
255,
0));
-
cvLine(img,cvPoint(pts[i].x+
2,pts[i].y
-2),cvPoint(pts[i].x
-2,pts[i].y+
2),CV_RGB(
0,
255,
0));
-
res[i]=
2;
-
}
-
else
-
{
-
cvLine(img,cvPoint(pts[i].x
-2,pts[i].y
-2),cvPoint(pts[i].x+
2,pts[i].y+
2),CV_RGB(
0,
0,
255));
-
cvLine(img,cvPoint(pts[i].x+
2,pts[i].y
-2),cvPoint(pts[i].x
-2,pts[i].y+
2),CV_RGB(
0,
0,
255));
-
res[i]=
3;
-
}
-
}
-
}
-
-
// 学习数据的现实
-
cvNamedWindow(
“SVM”,CV_WINDOW_AUTOSIZE);
-
cvShowImage(
“SVM”,img);
-
cvWaitKey(
0);
-
-
// 学习参数的生成
-
for (i=
0;i<s;++i)
-
{
-
data[i*
2] =
float(pts[i].x)/size;
-
data[i*
2+
1] =
float(pts[i].y)/size;
-
}
-
-
cvInitMatHeader(&data_mat,s,
2,CV_32FC1,data);
-
cvInitMatHeader(&res_mat,s,
1,CV_32SC1,res);
-
criteria = cvTermCriteria(CV_TERMCRIT_EPS,
1000,FLT_EPSILON);
-
param = CvSVMParams(CvSVM::C_SVC,CvSVM::RBF,
10.0,
8.0,
1.0,
10.0,
0.5,
0.1,
NULL,criteria);
-
-
// SVM type:CvSVM::C_SVC Kernel:CvSVM::RBF degree:10.0 gamma:8.0 coef0:1.0
-
-
svm.train(&data_mat,&res_mat,
NULL,
NULL,param);
-
-
// 学习结果绘图
-
for (i=
0;i<size;i++)
-
{
-
for (j=
0;j<size;j++)
-
{
-
CvMat m;
-
float ret =
0.0;
-
float a[] = {
float(j)/size,
float(i)/size};
-
cvInitMatHeader(&m,
1,
2,CV_32FC1,a);
-
ret = svm.predict(&m);
-
-
switch((
int)ret)
-
{
-
case
1:
-
rcolor = CV_RGB(
100,
0,
0);
-
break;
-
case
2:
-
rcolor = CV_RGB(
0,
100,
0);
-
break;
-
case
3:
-
rcolor = CV_RGB(
0,
0,
100);
-
break;
-
}
-
cvSet2D(img,i,j,rcolor);
-
}
-
}
-
-
-
// 为了显示学习结果,通过对输入图像区域的所有像素(特征向量)进行分类,然后对输入的像素用所属颜色等级的颜色绘图
-
for(i=
0;i<s;++i)
-
{
-
CvScalar rcolor;
-
switch(res[i])
-
{
-
case
1:
-
rcolor = CV_RGB(
255,
0,
0);
-
break;
-
case
2:
-
rcolor = CV_RGB(
0,
255,
0);
-
break;
-
case
3:
-
rcolor = CV_RGB(
0,
0,
255);
-
break;
-
}
-
cvLine(img,cvPoint(pts[i].x
-2,pts[i].y
-2),cvPoint(pts[i].x+
2,pts[i].y+
2),rcolor);
-
cvLine(img,cvPoint(pts[i].x+
2,pts[i].y
-2),cvPoint(pts[i].x
-2,pts[i].y+
2),rcolor);
-
}
-
-
// 支持向量的绘制
-
sv_num = svm.get_support_vector_count();
-
for (i=
0; i<sv_num;++i)
-
{
-
support = svm.get_support_vector(i);
-
cvCircle(img,cvPoint((
int)(support[
0]*size),(
int)(support[i]*size)),
5,CV_RGB(
200,
200,
200));
-
}
-
-
cvNamedWindow(
“SVM”,CV_WINDOW_AUTOSIZE);
-
cvShowImage(
“SVM”,img);
-
cvWaitKey(
0);
-
cvDestroyWindow(
“SVM”);
-
cvReleaseImage(&img);
-
-
return
0;
-
}
</div>
</div>