今早刚接触一个新的库——dlib(http://dlib.net),说实话,确实很好用。按照官方的介绍:These wrappers provide a portable object oriented interface for networking, multithreading, GUI development, and file browsing. Programs written using them can be compiled under POSIX or MS Windows platforms without changing the code. 总的来说,dlib是一个C++库,可用于开发涉及网络处理、多线程、图形界面、数据结构、线性代数、机器学习、XML和文本解析、数值优化、贝叶斯网络等诸多任务的可移植应用程序,几乎涵盖了数据分析的方方面面。更重要的是,与OpenCV类似,它提供了大量详细的示例(example),因此学习起来应该不难。由于今天是第一天接触,我就根据官方的SVM分类示例把它改写成了一个二分类器库,多类分类器以后再慢慢加进去。不过,功能应该还不太完善,总之先放上来,以后再慢慢改。目前支持nu-SVM和C-SVM两种训练器的参数调整,默认对gamma和C调参,因为这两个参数对结果的影响最大。
#include <QtCore>
#include <iostream>
#include <dlib/svm.h>
#include "dlib/rand/rand_kernel_abstract.h"
using namespace std;
using namespace dlib;
//svm二类分类器,调用前请修改nFeatures值;
namespace SVM{
#define nFeatures 2
// One training/test sample: a column vector with nFeatures doubles.
using sample_type = matrix<double, nFeatures, 1>;
// Gaussian RBF kernel over sample_type; its gamma is what findBestParam() tunes.
using kernel_type = radial_basis_kernel<sample_type>;
// Decision function that outputs a probability instead of a raw SVM score.
using probabilistic_funct_type = probabilistic_decision_function<kernel_type>;
// The probabilistic function bundled with the input normalizer, so callers can pass raw features.
using pfunct_type = normalized_function<probabilistic_funct_type>;
// Selects which dlib trainer the `opt` arguments of SVMClassification use.
enum Trainer{CTrainer = 1, NUTrainer = 2};
// Selects which container SVMClassification::loadData fills.
enum LoadType {LoadSamples = 1, LoadTestData = 2};
class SVMClassification{
public:
SVMClassification(){
samples.clear();
labels.clear();
}
~SVMClassification(){}
bool loadData(const char* fn, int opt = LoadSamples)
{
if(! QFile::exists(fn))
{
cout << fn << "does not exist!\n";
return false;
}
QFile infile(fn);
if (!infile.open(QIODevice::ReadOnly))
{
cout << fn << "open error!\n";
return false;
}
QTextStream _in(&infile);
QString smsg = _in.readLine();
QStringList slist;
if(opt == LoadSamples)
{
samples.clear();
labels.clear();
}
else
testData.clear();
while(! _in.atEnd())
{
sample_type samp;
smsg = _in.readLine();
slist = smsg.split(",");
for (int i = 0; i < nFeatures; i ++)
{
samp(i) = slist[i+1].trimmed().toDouble();
//cout << samp(i)<<" ";
}
if(opt == LoadSamples)
{
samples.push_back(samp);
labels.push_back(slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0);
//cout << (slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0)<<endl;
}
else
testData.push_back(samp);
}
infile.close();
return true;
}
//生成随机的样本数据;
bool generateRandomSamples(int num)
{
dlib::rand Rd;
for (int i = 0; i < num; i ++)
{
sample_type samp;
for (int j = 0; j < nFeatures; j ++)
{
samp(j) = Rd.get_random_gaussian();
}
samples.push_back(samp);
double _label = (double)(Rd.get_random_16bit_number()%2);
if(_label == 0)
_label =-1;
labels.push_back(_label);
}
cout << "randomly generated "<<num<<" samples\n";
return true;
}
bool normalization()
{
//归一化;
normalizer.train(samples);//获取均值和方差;
for (unsigned long i = 0; i < samples.size(); ++i)
samples[i] = normalizer(samples[i]);
//将样本数据打乱,以用于多次交叉验证;
randomize_samples(samples, labels);
return true;
}
//计算最优的参数-gammma, nu;
bool findBestParam(int opt = CTrainer)
{
//根据正负标签比例计算参数nu的最大值;
const double max_nu = maximum_nu(labels);
cout << "max_nu = "<< max_nu <<endl;
cout << "doing cross validation..." << endl;
matrix<double> best_result(1, 2);
best_result = 0;
best_gamma = 0.0001, best_nu = 0.0001, best_c = 5;
switch(opt)
{
case NUTrainer:
for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
{
for (double nu = 0.00001; nu < max_nu; nu *= 5)
{
trainer.set_kernel(kernel_type(gamma));
trainer.set_nu(nu);
cout << "gamma: " << gamma << " nu: " << nu;
matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
cout << " cross validation accuracy: " << result;
if (sum(result) > sum(best_result))
{
best_result = result;
best_gamma = gamma;
best_nu = nu;
}
}
}
cout << "\nbest gamma: " << best_gamma <<" best nu: " << best_nu<< " best score: "<<best_result<<"mean acc: "<<mean(best_result) << endl;
break;
case CTrainer:
for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
{
for (double _c = 1; _c < 2000; _c *= 2)
{
c_trainer.set_kernel(kernel_type(gamma));
c_trainer.set_c(_c);
cout << "gamma: " << gamma << " C: " << _c;
matrix<double> result = cross_validate_trainer(c_trainer, samples, labels, 10);
cout << " cross validation accuracy: " << result;
if (sum(result) > sum(best_result))
{
best_result = result;
best_gamma = gamma;
best_c = _c;
}
}
}
cout << "\nbest gamma: " << best_gamma <<" best c: " << best_c<< " best score: "<<best_result<<"mean acc: "<<mean(best_result) << endl;
break;
}
return true;
}
void setGamma(double _gamma)
{
best_gamma = _gamma;
}
void setNu(double _nu)
{
best_nu = _nu;
}
void setC(double _c)
{
best_c = _c;
}
bool addSample(double* _pData, double _label)
{
sample_type samp;
for (int i = 0; i < nFeatures; i ++)
{
samp(i) = _pData[i];
}
samples.push_back(samp);
labels.push_back(_label);
return true;
}
//clear the samples and labels and reset from a 2-D array;
bool setSamples(double** _ppSamples, int _nsamples)
{
samples.clear();
labels.clear();
for (int i = 0; i < _nsamples; i ++)
{
sample_type samp;
for (int j = 0; j < nFeatures; j ++)
{
samp(j) = _ppSamples[i][j];
}
samples.push_back(samp);
labels.push_back(_ppSamples[i][nFeatures]);
}
return true;
}
//学习训练分类器;
bool learnFunc(int opt = CTrainer)
{
switch(opt)
{
case CTrainer:
c_trainer.set_kernel(kernel_type(best_gamma));
c_trainer.set_c(best_nu);
learned_pfunct.normalizer = normalizer;
learned_pfunct.function = train_probabilistic_decision_function(c_trainer, samples, labels, 3);
break;
case NUTrainer:
trainer.set_kernel(kernel_type(best_gamma));
trainer.set_nu(best_nu);
learned_pfunct.normalizer = normalizer;
learned_pfunct.function = train_probabilistic_decision_function(trainer, samples, labels, 3);
break;
}
cout << "\nnumber of support vectors in our learned_pfunct is "
<< learned_pfunct.function.decision_funct.basis_vectors.size() << endl;
return true;
}
double predictProbability(sample_type _samp)
{
return learned_pfunct(_samp);
}
//预测概率;
double predictProbability(double* _val)
{
sample_type samp;
for (int i = 0; i<nFeatures; i ++)
samp(i) = _val[i];
return learned_pfunct(samp);
}
//将分类器另存为文件;
bool saveLearnedFunc(const char* fn)
{
serialize(fn) << learned_pfunct;
cout <<"saved learned function to "<< fn<<endl;
return true;
}
//从文件中读取分类器;
bool loadLearnedFunc(const char* fn)
{
deserialize(fn) >> learned_pfunct;
cout <<"loaded learned function from "<< fn<<endl;
return true;
}
//用一定数目的支持向量来交叉验证的精度结果;
bool getAccByCrossValidateWithVectors(int nVectors, int opt = CTrainer)
{
cout << "\ncross validation accuracy with only "<<nVectors<<" support vectors: " ;
switch(opt)
{
case CTrainer:
cout << cross_validate_trainer(reduced2(c_trainer, nVectors), samples, labels, 3);
break;
case NUTrainer:
cout << cross_validate_trainer(reduced2(trainer, nVectors), samples, labels, 3);
break;
}
return true;
}
private:
std::vector<sample_type> samples;
std::vector<double> labels;
std::vector<sample_type> testData;
svm_nu_trainer<kernel_type> trainer;
svm_c_trainer<kernel_type> c_trainer;
vector_normalizer<sample_type> normalizer;
double best_gamma;
double best_nu;
double best_c;
pfunct_type learned_pfunct;
protected:
};
}