以下为部分代码,仅供参考。文中很多变量类型是作者自己定义的数据结构(如 sample_type、kernel_type、funct_type)。遗憾的是,纯 C 实现的 SVM 代码已经找不到了,有空再重写一个吧。
头文件:
#ifndef SVM_C_H
#define SVM_C_H
#include <vector>
#include"Process.h"
// Builds the global training set: positive samples are labelled +1, negative
// samples -1; the combined set is normalized and shuffled.
extern void Label(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample);
// Searches over kernel gamma and C values, reporting cross-validation
// accuracy for each pair on the set built by Label().
extern void Train();
// NOTE(review): definitions of Test/Classify are not visible here.
extern void Test();
extern void Classify();
// NOTE(review): presumably a parameter search on the given sample sets — confirm.
extern void ParamsSelection(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample);
// NOTE(review): presumably the OpenCV (CvSVM) counterparts of the routines
// above, using the "ori"/"opt" parameter sets below — confirm.
extern void Train_opencv_ori();
extern void Train_opencv_opt();
extern void Classify_opencv(int mode);
void SVM_Params_ori();
void SVM_Params_opt();
#endif // SVM_C_H
cpp文件:
#include "StdAfx.h"
#include"Svm_c.h"
// ---- dlib-based SVM state shared by Label() and Train() ----
// C-SVM trainer; its kernel and C parameter are set inside Train().
dlib::svm_c_trainer<kernel_type>trainer;
// Training samples filled by Label(): normalized, then shuffled.
std::vector<sample_type>AllSamples;
// Labels aligned index-by-index with AllSamples: +1 positive, -1 negative.
std::vector<double>All_labels;
// NOTE(review): not used in the visible code; presumably holds the decision
// function produced by training — confirm.
funct_type learned_function;
// Feature normalizer; trained on AllSamples inside Label().
dlib::vector_normalizer<sample_type>normalizer;
// NOTE(review): not referenced in the visible code.
dlib::rand rnd;
//cv::Mat Classes;
// ---- OpenCV (CvSVM) state ----
// NOTE(review): presumably configured by SVM_Params_ori()/SVM_Params_opt()
// and used by the *_opencv routines — confirm.
CvSVMParams SVM_params;
CvSVM svm;
// NOTE(review): name looks like a misspelling of "response"; unused in the visible code.
int respones;
// NOTE(review): presumably counts of samples predicted positive / negative — confirm.
int PrePoNum=0;
int PreNgNum=0;
std::vector<int>PSIndex;// indices of samples classified as positive
std::vector<int>NGIndex;// indices of samples classified as negative
// NOTE(review): presumably search grids for nu, coef0 and degree used by an
// automatic parameter search (e.g. CvSVM::train_auto) — confirm.
CvParamGrid nuGrid;
CvParamGrid coeffGrid;
CvParamGrid degreeGrid;
void Label(std::vector<sample_type> &PSample,std::vector<sample_type> &NSample)
{
std::cout<<"labeling..."<<std::endl;
AllSamples.clear();
int PSnum=PSample.size();
int NGnum=NSample.size();
int Num=0;
Num=PSnum+NGnum;
if(PSnum>0&&NGnum>0)
{
for(int i=0;i<Num;++i)
{
if(i<PSnum)
{
AllSamples.push_back(PSample[i]);
All_labels.push_back(1);
}
else
{
AllSamples.push_back(NSample[i-PSnum]);
All_labels.push_back(-1);
}
}
normalizer.train(AllSamples);
for(unsigned long i=0;i<AllSamples.size();++i)
{
AllSamples[i]=normalizer(AllSamples[i]);
std::cout<<AllSamples[i](0)<<" "<<AllSamples[i](1)<<" "<<AllSamples[i](2)<<" "<<AllSamples[i](3)<<std::endl;
}
dlib::randomize_samples(AllSamples,All_labels);//可要可不要;
}
else
std::cout<<"训练样本无效!"<<std::endl;
}
void Train()
{
if(AllSamples.size()>0)
{
std::cout<<"Doing cross calidation"<<std::endl;
for(double gamma=0.00001;gamma<=1;gamma*=5)
{
for(double C=1;C<100000;C*=5)
{
trainer.set_kernel(kernel_type(gamma));
trainer.set_c(C);
std::cout<<"gamma: "<<gamma<<" C: "<<C;
std::cout<<" cross validation accuracy: "<<dlib::cross_validate_trainer(trainer,AllSampl