SVM 算法

关于 SVM 算法的详细介绍,请参阅原文中给出的链接。下面的 SVM 实现基于 OpenCV 的 CvSVM 类:


#ifndef __SVM_TRAIN__
#define __SVM_TRAIN__
#include<opencv2\opencv.hpp>
#include <opencv2/core/core.hpp>
//#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>
//#include"ml.h"
/// Training-set storage used by svm_Train: a feature matrix and the
/// matching label column. Both matrices are allocated by svm_Train's
/// constructor with cvCreateMat.
struct Vector_data
{
	CvMat *data_mat;   // feature matrix: one training sample per row
	CvMat *class_mat;  // class label / target value of each sample, one per row
};
typedef Vector_data *vector_data;
class MySvm: public CvSVM  
{  
public:  
    int get_alpha_count()  
    {  
        return this->sv_total;  
    }  
  
   int get_sv_dim()  
    {  
        return this->var_all;  
    }  
  
   int get_sv_count()  
    {  
       return this->decision_func->sv_count;  
   }  
 
   double* get_alpha()  
   {  
       return this->decision_func->alpha;  
    }  
  
    float** get_sv()  
    {  
        return this->sv;  
    }  
  
    float get_rho()  
    {  
        return this->decision_func->rho;  
    }  
}; 

/// Wrapper around MySvm: holds the training matrices and SVM parameters and
/// provides train / save / predict entry points.
///
/// The object holds raw heap pointers (vector_element, tsvm) that the
/// destructor releases, so a member-wise copy would release them twice.
/// Copy construction and copy assignment are therefore disabled.
class svm_Train
{
public:
	/// Allocates a sample_count x sample_size feature matrix, a
	/// sample_count x 1 label matrix and the model; sets svm_param.
	svm_Train(int sample_count,int sample_size);
	~svm_Train();
	/// Copies data[i][j] (sample_count rows of sample_size floats) and
	/// label[i] into the internal matrices.
	void svm_SetData(float** data, float* label);
	/// Trains the model on the stored data; requires svm_SetData() first.
	void svm_StartTrain();
	/// Saves the trained model to "detector.xml".
	void svm_Save();
	/// Predicts the response for TestMat (name spelling kept for callers).
	float svm_Perdict(CvMat* TestMat );

	vector_data vector_element; ///< training matrices (struct allocated with malloc)
	MySvm* tsvm;                ///< the SVM model (allocated with new)
	CvSVMParams svm_param;      ///< parameters passed to train()
	int sample_count;           ///< number of training samples (matrix rows)
	int sample_size;            ///< features per sample (matrix columns)
	bool isTrain;               ///< a model has been trained
	bool isSetData;             ///< training data has been loaded
	bool isSave;                ///< the model has been saved

private:
	// Non-copyable (C++03 idiom: declared private, never defined).
	svm_Train(const svm_Train&);
	svm_Train& operator=(const svm_Train&);
};

#endif

#include"svm_Train.h"
/// Allocates the training matrices and the model and configures svm_param.
///
/// @param sample_count number of training samples (rows of both matrices)
/// @param sample_size  number of features per sample (columns of data_mat)
///
/// Effective SVM configuration: EPS_SVR with an RBF kernel, C = 300,
/// p = 1e-3, gamma = 0.29. All other CvSVMParams fields keep their defaults.
svm_Train::svm_Train(int sample_count,int sample_size)
{
	this->sample_count=sample_count;
	this->sample_size=sample_size;
	isTrain=false;
	isSetData=false;
	isSave=false;

	// The destructor releases this struct with free(), so it is malloc'd here.
	this->vector_element=(vector_data)malloc(sizeof(Vector_data));
	// One training sample per row; one label per row.
	this->vector_element->data_mat=cvCreateMat(sample_count,sample_size,CV_32FC1);
	this->vector_element->class_mat=cvCreateMat(sample_count,1,CV_32FC1);
	this->tsvm=new MySvm();

	// Other combinations experimented with: svm_type C_SVC / NU_SVR and
	// kernel_type LINEAR / POLY / SIGMOID. Change the two lines below to switch.
	// (kernel_type is assigned exactly once; earlier LINEAR/SIGMOID stores were dead.)
	this->svm_param.svm_type = CvSVM::EPS_SVR;
	this->svm_param.kernel_type = CvSVM::RBF;
	this->svm_param.C = 300;
	this->svm_param.p = 1e-3;
	this->svm_param.gamma = 0.29; // alternative that was tried: 1./this->sample_size
}

void svm_Train::svm_SetData( float** data, float* label)
{
	/*cvInitMatHeader(vector_element->class_mat,sample_count,1,CV_32FC1,label);  
    cvInitMatHeader(vector_element->data_mat,sample_count,sample_size,CV_32FC1,data);  */
    //cvInitMatHeader(&mtestd,1,2,CV_32FC1,testd);
	for(int i=0;i<sample_count;i++)
	{
	 vector_element->class_mat->data.fl[i]=label[i];
	 for(int j=0;j<sample_size;j++)
	{
     vector_element->data_mat->data.fl[i*sample_size+j]=data[i][j];
	}
	}

	/*for(int i=0;i<sample_count;i++)
	{
	 vector_element->class_mat->data.fl[i]=label[i];
	 for(int j=0;j<sample_size;j++)
	{
     vector_element->data_mat->data.fl[j*sample_count+i]=data[i][j];
	}
	}*/

	isSetData=true;
	 
}

/// Trains tsvm on the stored training set using svm_param.
///
/// Requires svm_SetData() to have been called. isTrain is set from the
/// bool returned by train(), so a failed run leaves isTrain false and
/// svm_Save()/svm_Perdict() refuse to use the model.
void svm_Train::svm_StartTrain()
{
	if(!isSetData)
	{
		printf("Please set train data first\n");
		return;
	}
	// varIdx = 0, sampleIdx = 0: train on every variable and every sample.
	isTrain=tsvm->train(vector_element->data_mat, vector_element->class_mat,0,0,svm_param);
	if(!isTrain)
		printf("SVM training failed\n");
}

/// Saves the trained model to "detector.xml" (relative to the working
/// directory).
///
/// Prints a message and does nothing if no model has been trained.
/// isSave is set only after save() has been called.
void svm_Train::svm_Save()
{
	if(!isTrain)
	{
		printf("Please training data first\n");
		return;
	}
	tsvm->save("detector.xml", 0); //For prediction
	isSave=true;
}

/// Predicts the response of the trained model for one sample.
///
/// @param TestMat the sample to evaluate; presumably laid out like one row
///                of the training matrix — confirm with callers.
/// @return the model's prediction, or 0 when no model has been trained or
///         TestMat is NULL (a message is printed in both cases).
float svm_Train::svm_Perdict(CvMat* TestMat)
{
	if(!isTrain)
	{
		printf("Please training data first\n");
		return 0.0f;
	}
	if(!TestMat)
	{
		printf("Invalid test data\n");
		return 0.0f;
	}
	return tsvm->predict(TestMat);
}
/// Releases everything the constructor allocated:
/// - the two cvCreateMat matrices (via cvReleaseMat);
/// - the Vector_data struct (allocated with malloc, so released with free);
/// - the model (allocated with new, so it must be released with delete —
///   free() would skip the CvSVM destructor and is undefined behavior).
svm_Train::~svm_Train()
{
	if(vector_element)
	{
		cvReleaseMat(&vector_element->data_mat);
		cvReleaseMat(&vector_element->class_mat);
		free(vector_element);
	}
	delete tsvm;
}


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值