一个偷偷写的svm库

今早刚接触一个新的库——dlib(http://dlib.net),讲真,真的很好用。按照官方的介绍,就是:These wrappers provide a portable object oriented interface for networking, multithreading, GUI development, and file browsing. Programs written using them can be compiled under POSIX or MS Windows platforms without changing the code.也就是说,DLIB是一个C ++库,用于开发可移植的应用程序与网络处理,线程,图形界面,数据结构,线性代数,机器学习,XML和文本解析,数值优化,贝叶斯网,和许多其他任务。几乎涉及到数据分析的方方面面了。更重要的是,类似于openCV,它提供很多很详细的example,因此学习起来应该不难。由于今天第一天接触,就根据svm分类的example把它改写成了一个二类分类库,多类分类器以后再慢慢加进去。不过,功能应该不太完善。总之,先放上来吧,以后再慢慢改,目前是涉及到nu和C参数的调整,默认是对ganmma和C调参,因为这两个对结果影响最大嘛。


#include <QtCore>
#include <iostream>
#include <dlib/svm.h>
#include "dlib/rand/rand_kernel_abstract.h"

using namespace std;
using namespace dlib;

//svm二类分类器,调用前请修改nFeatures值;
namespace SVM{

	#define nFeatures 2
	typedef matrix<double, nFeatures, 1> sample_type;//定义数据类型;
	typedef radial_basis_kernel<sample_type> kernel_type;//定义核类型;

	typedef probabilistic_decision_function<kernel_type> probabilistic_funct_type;  
	typedef normalized_function<probabilistic_funct_type> pfunct_type;

	enum Trainer{CTrainer = 1, NUTrainer = 2};
	enum LoadType {LoadSamples = 1, LoadTestData = 2};

	class SVMClassification{
	public:
		SVMClassification(){
			samples.clear();
			labels.clear();
		}
		~SVMClassification(){}
		
		bool loadData(const char* fn, int opt = LoadSamples)
		{
			if(! QFile::exists(fn))
			{
				cout << fn << "does not exist!\n";
				return false;
			}
			QFile infile(fn);
			if (!infile.open(QIODevice::ReadOnly))
			{
				cout << fn << "open error!\n";
				return false;
			}
			QTextStream _in(&infile);
			QString smsg = _in.readLine();
			QStringList slist;
			if(opt == LoadSamples)
			{
				samples.clear();
				labels.clear();
			}
			else
				testData.clear();
			

			while(! _in.atEnd())
			{
				sample_type samp;
				smsg = _in.readLine();
				slist = smsg.split(",");
				for (int i = 0; i < nFeatures; i ++)
				{
					samp(i) = slist[i+1].trimmed().toDouble();
					//cout << samp(i)<<" ";
				}
				
				if(opt == LoadSamples)
				{
					samples.push_back(samp);
					labels.push_back(slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0);
					//cout << (slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0)<<endl;
				}
				else
					testData.push_back(samp);
			}
			infile.close();
			return true;
		}
		//生成随机的样本数据;
		bool generateRandomSamples(int num)
		{
			dlib::rand Rd;
			for (int i = 0; i < num; i ++)
			{
				sample_type samp;
				for (int j = 0; j < nFeatures; j ++)
				{
					samp(j) = Rd.get_random_gaussian();
				}
				samples.push_back(samp);
				double _label = (double)(Rd.get_random_16bit_number()%2);
				if(_label == 0)
					_label =-1;
				labels.push_back(_label);
			}
			cout << "randomly generated "<<num<<" samples\n";
			return true;
		}
		bool normalization()
		{
			//归一化;
			normalizer.train(samples);//获取均值和方差;
			for (unsigned long i = 0; i < samples.size(); ++i)
				samples[i] = normalizer(samples[i]); 
			//将样本数据打乱,以用于多次交叉验证;
			randomize_samples(samples, labels);
			return true;
		}
		//计算最优的参数-gammma, nu;
		bool findBestParam(int opt = CTrainer)
		{
			//根据正负标签比例计算参数nu的最大值;
			const double max_nu = maximum_nu(labels);
			cout << "max_nu = "<< max_nu <<endl;

			cout << "doing cross validation..." << endl;
			matrix<double> best_result(1, 2);
			best_result = 0;
			best_gamma = 0.0001, best_nu = 0.0001, best_c = 5;

			switch(opt)
			{
			case NUTrainer:
				for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
				{
					for (double nu = 0.00001; nu < max_nu; nu *= 5)
					{
						trainer.set_kernel(kernel_type(gamma));
						trainer.set_nu(nu);

						cout << "gamma: " << gamma << "    nu: " << nu;
						matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
						cout << "     cross validation accuracy: " << result;

						if (sum(result) > sum(best_result))
						{
							best_result = result;
							best_gamma = gamma;
							best_nu = nu;
						} 
					}
				}
				cout << "\nbest gamma: " << best_gamma <<"      best nu: " << best_nu<< "      best score: "<<best_result<<"mean acc:  "<<mean(best_result) << endl; 
				break;
			case CTrainer:
				for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
				{
					for (double _c = 1; _c < 2000; _c *= 2)
					{
						c_trainer.set_kernel(kernel_type(gamma));
						c_trainer.set_c(_c);

						cout << "gamma: " << gamma << "    C: " << _c;
						matrix<double> result = cross_validate_trainer(c_trainer, samples, labels, 10);
						cout << "     cross validation accuracy: " << result;

						if (sum(result) > sum(best_result))
						{
							best_result = result;
							best_gamma = gamma;
							best_c = _c;
						} 
					}
				}
				cout << "\nbest gamma: " << best_gamma <<"      best c: " << best_c<< "      best score: "<<best_result<<"mean acc:  "<<mean(best_result) << endl; 
				break;
			}
			return true;
		}
		
		void setGamma(double _gamma)
		{
			best_gamma = _gamma;
		}

		void setNu(double _nu)
		{
			best_nu = _nu;
		}

		void setC(double _c)
		{
			best_c = _c;
		}

		bool addSample(double* _pData, double _label)
		{
			sample_type samp;
			for (int i = 0; i < nFeatures; i ++)
			{
				samp(i) = _pData[i];
			}
			samples.push_back(samp);
			labels.push_back(_label);
			return true;
		}

		//clear the samples and labels and reset from a 2-D array;
		bool setSamples(double** _ppSamples, int _nsamples)
		{
			samples.clear();
			labels.clear();
			for (int i = 0; i < _nsamples; i ++)
			{
				sample_type samp;
				for (int j = 0; j < nFeatures; j ++)
				{
					samp(j) = _ppSamples[i][j];
				}
				samples.push_back(samp);
				labels.push_back(_ppSamples[i][nFeatures]);
			}
			return true;
		}

		//学习训练分类器;
		bool learnFunc(int opt = CTrainer)
		{
			switch(opt)
			{
			case CTrainer:
				c_trainer.set_kernel(kernel_type(best_gamma));
				c_trainer.set_c(best_nu);

				learned_pfunct.normalizer = normalizer;
				learned_pfunct.function = train_probabilistic_decision_function(c_trainer, samples, labels, 3);
				break;
			case NUTrainer:
				trainer.set_kernel(kernel_type(best_gamma));
				trainer.set_nu(best_nu);

				learned_pfunct.normalizer = normalizer;
				learned_pfunct.function = train_probabilistic_decision_function(trainer, samples, labels, 3);
				break;
			}
			
			cout << "\nnumber of support vectors in our learned_pfunct is " 
				<< learned_pfunct.function.decision_funct.basis_vectors.size() << endl;
			return true;
		}

		double predictProbability(sample_type _samp)
		{
			return learned_pfunct(_samp);
		}
		//预测概率;
		double predictProbability(double* _val)
		{
			sample_type samp;
			for (int i = 0; i<nFeatures; i ++)
				samp(i) = _val[i];
			return learned_pfunct(samp);
		}
		//将分类器另存为文件;
		bool saveLearnedFunc(const char* fn)
		{
			 serialize(fn) << learned_pfunct;
			 cout <<"saved learned function to "<< fn<<endl;
			 return true;
		}
		//从文件中读取分类器;
		bool loadLearnedFunc(const char* fn)
		{
			deserialize(fn) >> learned_pfunct;
			cout <<"loaded learned function from "<< fn<<endl;
			return true;
		}
		//用一定数目的支持向量来交叉验证的精度结果;
		bool getAccByCrossValidateWithVectors(int nVectors, int opt = CTrainer)
		{
			cout << "\ncross validation accuracy with only "<<nVectors<<" support vectors: " ;
			switch(opt)
			{
			case CTrainer:
				cout << cross_validate_trainer(reduced2(c_trainer, nVectors), samples, labels, 3);
				break;
			case NUTrainer:
				cout << cross_validate_trainer(reduced2(trainer, nVectors), samples, labels, 3);
				break;
			}
			return true;
		}

	private:
		std::vector<sample_type> samples;
		std::vector<double> labels;
		std::vector<sample_type> testData;
		svm_nu_trainer<kernel_type> trainer;
		svm_c_trainer<kernel_type> c_trainer;
		vector_normalizer<sample_type> normalizer;
		double best_gamma;
		double best_nu;
		double best_c;
		pfunct_type learned_pfunct; 
	protected:
	};
}



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值