关于 SVM 算法的详细讲解，可以点击原文链接查看。下面的 SVM 算法实现基于 OpenCV 的 CvSVM 类：
#ifndef __SVM_TRAIN__
#define __SVM_TRAIN__
#include<opencv2\opencv.hpp>
#include <opencv2/core/core.hpp>
//#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>
//#include"ml.h"
// Pair of OpenCV matrices holding a training set: the samples and, for each
// sample, its response value.
struct Vector_data
{
    CvMat *data_mat;   // samples (one row per sample)
    CvMat *class_mat;  // response / class value of each sample
};
typedef Vector_data *vector_data;
class MySvm: public CvSVM
{
public:
int get_alpha_count()
{
return this->sv_total;
}
int get_sv_dim()
{
return this->var_all;
}
int get_sv_count()
{
return this->decision_func->sv_count;
}
double* get_alpha()
{
return this->decision_func->alpha;
}
float** get_sv()
{
return this->sv;
}
float get_rho()
{
return this->decision_func->rho;
}
};
// Small wrapper that trains, saves and queries an OpenCV SVM.
//
// Usage order: construct -> svm_SetData -> svm_StartTrain -> svm_Save /
// svm_Perdict. The object holds raw pointers (vector_element, tsvm) that its
// destructor frees, so it must not be copied: a shallow copy would free the
// same memory twice. Copying is therefore disabled below.
class svm_Train
{
public:
    // sample_count: number of training samples; sample_size: features per sample.
    svm_Train(int sample_count,int sample_size);
    ~svm_Train();
    // Copies sample_count x sample_size samples and sample_count labels.
    void svm_SetData(float** data, float* label);
    // Trains tsvm on the stored data with svm_param.
    void svm_StartTrain();
    // Writes the trained model to disk.
    void svm_Save();
    // Runs the trained model on one sample (name kept as-is for callers).
    float svm_Perdict(CvMat* TestMat );

    vector_data vector_element;   // training buffers, allocated in the constructor
    MySvm* tsvm;                  // the underlying SVM, allocated in the constructor
    CvSVMParams svm_param;        // training parameters, set in the constructor
    int sample_count;
    int sample_size;
    bool isTrain,isSetData,isSave;

private:
    // Non-copyable (C++03 idiom: declared but never defined).
    svm_Train(const svm_Train&);
    svm_Train& operator=(const svm_Train&);
};
#endif
#include"svm_Train.h"

#include <cstdio>
#include <cstdlib>
#include <limits>
#include <new>
// Allocates the training buffers and the SVM object and sets the SVM
// parameters.
//
// sample_count: number of training samples (rows of data_mat / class_mat).
// sample_size:  number of features per sample (columns of data_mat).
//
// Active configuration: epsilon-SVR (EPS_SVR) with an RBF kernel, C = 300,
// p = 1e-3, gamma = 0.29. Every other CvSVMParams field keeps its default.
// Throws std::bad_alloc if the buffer struct cannot be allocated.
svm_Train::svm_Train(int sample_count,int sample_size)
{
    this->sample_count=sample_count;
    this->sample_size=sample_size;
    isTrain=false;
    isSetData=false;
    isSave=false;

    // Allocated with malloc because the destructor releases it with free().
    this->vector_element=static_cast<vector_data>(malloc(sizeof(Vector_data)));
    if(!this->vector_element)
        throw std::bad_alloc();
    // One sample per row; the responses form a single column.
    this->vector_element->data_mat=cvCreateMat(sample_count,sample_size,CV_32FC1);
    this->vector_element->class_mat=cvCreateMat(sample_count,1,CV_32FC1);
    this->tsvm=new MySvm();

    // Configurations tried earlier and not in effect: C_SVC; LINEAR, POLY and
    // SIGMOID kernels; NU_SVR (LINEAR, C=1, gamma=5, nu=0.9). Only the values
    // below are used for training.
    this->svm_param.svm_type = CvSVM::EPS_SVR;
    this->svm_param.kernel_type = CvSVM::RBF;
    this->svm_param.C = 300;
    this->svm_param.p = 1e-3;
    this->svm_param.gamma = 0.29;   // alternative tried: 1./sample_size
}
void svm_Train::svm_SetData( float** data, float* label)
{
/*cvInitMatHeader(vector_element->class_mat,sample_count,1,CV_32FC1,label);
cvInitMatHeader(vector_element->data_mat,sample_count,sample_size,CV_32FC1,data); */
//cvInitMatHeader(&mtestd,1,2,CV_32FC1,testd);
for(int i=0;i<sample_count;i++)
{
vector_element->class_mat->data.fl[i]=label[i];
for(int j=0;j<sample_size;j++)
{
vector_element->data_mat->data.fl[i*sample_size+j]=data[i][j];
}
}
/*for(int i=0;i<sample_count;i++)
{
vector_element->class_mat->data.fl[i]=label[i];
for(int j=0;j<sample_size;j++)
{
vector_element->data_mat->data.fl[j*sample_count+i]=data[i][j];
}
}*/
isSetData=true;
}
// Trains the SVM on the data stored by svm_SetData using svm_param.
//
// Prints a message and does nothing if no data has been set. isTrain reflects
// the bool returned by CvSVM::train, so a failed training leaves the object
// marked as untrained.
void svm_Train::svm_StartTrain()
{
    if(!isSetData)
    {
        printf("Please set train data first\n");
        return;
    }
    isTrain=tsvm->train(vector_element->data_mat, vector_element->class_mat,0,0,svm_param);
    if(!isTrain)
        printf("SVM training failed\n");
}
// Saves the trained model to "detector.xml". The path is relative, so it is
// resolved against the process's working directory.
//
// If no model has been trained, a message is printed and isSave is left
// unchanged. Previously isSave was set even when nothing was written.
void svm_Train::svm_Save()
{
    if(!isTrain)
    {
        printf("Please training data first\n");
        return;
    }
    tsvm->save("detector.xml", 0); //For prediction
    isSave=true;
}
// Runs the trained SVM on one sample and returns CvSVM::predict's result.
// With the EPS_SVR configuration set in the constructor, this is the
// regression estimate.
//
// TestMat: the sample. Assumed to be a 1 x sample_size CV_32FC1 row; TODO
// confirm with callers.
//
// Returns quiet NaN, after printing a message, when no model is trained or
// TestMat is NULL. Previously an uninitialized value was returned in the
// untrained case.
float svm_Train::svm_Perdict(CvMat* TestMat)
{
    if(!isTrain)
    {
        printf("Please training data first\n");
        return std::numeric_limits<float>::quiet_NaN();
    }
    if(!TestMat)
    {
        printf("svm_Perdict: TestMat must not be NULL\n");
        return std::numeric_limits<float>::quiet_NaN();
    }
    return tsvm->predict(TestMat);
}
// Releases everything the constructor allocated: both CvMat buffers, the
// malloc'ed Vector_data struct, and the MySvm created with new.
svm_Train::~svm_Train()
{
    if(vector_element)
    {
        // cvReleaseMat frees the matrix and sets the pointer to NULL; it is
        // a no-op when the pointer is already NULL.
        cvReleaseMat(&vector_element->data_mat);
        cvReleaseMat(&vector_element->class_mat);
        free(vector_element);   // allocated with malloc in the constructor
        vector_element=NULL;
    }
    // tsvm was created with new, so it must be destroyed with delete; free()
    // would skip the CvSVM destructor and is undefined behavior.
    delete tsvm;
    tsvm=NULL;
}