OpenCV中随机森林的实现与字符识别例子

之前的一篇文章简单介绍了随机森林,并给出了一些随机森林的相关资源:http://blog.csdn.net/holybin/article/details/25653597

在opencv中随机森林的实现为CvRTrees类,version2.0及以下版本定义于\OpenCV2.0\include\opencv\ml.h,version2.0以上版本定义于\OpenCV2.4.0\modules\ml\include\opencv2\ml\ml.hpp,实现在\OpenCV2.4.0\modules\ml\src\rtrees.cpp,不同版本的实现略有不同。

以opencv2.4.0为例子,其定义如下:

// Random-trees (random forest) model declaration, quoted from the OpenCV 2.4.0
// ml.hpp header. It derives from CvStatModel and exposes training, prediction,
// variable-importance, proximity and (de)serialization entry points.
class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();
    // Trains the forest from CvMat inputs. Every argument after `responses`
    // is optional (defaults to 0 / default-constructed CvRTParams).
    // NOTE(review): the exact meaning of tflag/varIdx/sampleIdx/varType is
    // defined by the OpenCV ML documentation, not by this declaration.
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );
    
    // Trains from a CvMLData container instead of separate matrices.
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    // Predicts the response for one sample; `missing` is an optional mask.
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    // Same inputs as predict(); presumably returns a probability-like score
    // rather than a label -- confirm against the OpenCV documentation.
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

    // cv::Mat-based overloads of the functions above; hidden when SWIG is defined.
#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif
    
    CV_WRAP virtual void clear();

    // Returns the variable-importance vector as a CvMat (may be null; the
    // sample program below checks the pointer before using it).
    virtual const CvMat* get_var_importance();
    // Returns a proximity measure between two samples, with optional
    // missing-value masks for each of them.
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
    
    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    virtual float get_train_error();    

    // Serialization to / from an OpenCV file storage node.
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    // Accessors for the number of trees and for an individual tree by index.
    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:
    virtual std::string getName() const;

    // Grows the forest under the given termination criteria.
    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;
    int nclasses;
    // NOTE(review): presumably the out-of-bag error estimate -- confirm.
    double oob_error;
    CvMat* var_importance;
    int nsamples;

    cv::RNG* rng;
    CvMat* active_var_mask;
};

这里的CvStatModel是OpenCV的机器学习模块(The Machine Learning Library,MLL)的基类,包括KNN,Bayes,SVM等诸多实现都是基于该类,参考opencv的文档:docs.opencv.org/modules/ml/doc/ml.html,附个SVM的使用例子:docs.opencv.org/doc/tutorials/ml/non_linear_svms/non_linear_svms.html#nonlinearsvms。


以下使用CvRTrees类来对字符数据作分类,该例子即opencv附带的例子“\OpenCV2.4.0\samples\cpp\letter_recog.cpp”,字符数据“\OpenCV2.4.0\samples\cpp\letter-recognition.data”来源于UCI,还有一个csv格式的,这个网站还有很多很好的机器学习数据库。

在本例子中,字符数据“letter-recognition.data”共有20000个样本,每个样本占一行,包含1个类别标签(第1个字段,大写字母)和16维整数特征(第2~17个字段),各字段含义如下:

 1. lettr — capital letter (26 values from A to Z)

 2. x-box — horizontal position of box (integer)

 3. y-box — vertical position of box (integer)

 4. width — width of box (integer)

 5. high — height of box (integer)

 6. onpix — total # on pixels (integer)

 7. x-bar — mean x of on pixels in box (integer)

 8. y-bar — mean y of on pixels in box (integer)

 9. x2bar — mean x variance (integer)

10. y2bar — mean y variance (integer)

11. xybar — mean x y correlation (integer)

12. x2ybr — mean of x * x * y (integer)

13. xy2br — mean of x * y * y (integer)

14. x-ege — mean edge count left to right (integer)

15. xegvy — correlation of x-ege with y (integer)

16. y-ege — mean edge count bottom to top (integer)

17. yegvx — correlation of y-ege with x (integer)

程序中使用前16000个进行训练,后4000个进行测试。

#include "opencv2/core/core_c.h"
#include "opencv2/ml/ml.hpp"

#include <cstdio>
#include <vector>
/*
	Modified from F:\Program Files\OpenCV2.4.0\samples\cpp\letter_recog.cpp
	Only the Random Forest (RF) method is retained; the other classifiers were removed.
*/

// Prints a short description of the sample and its command-line usage to
// stdout. Only the Random Trees classifier is available in this trimmed-down
// version, so the text no longer mentions the other classifiers.
void help()
{
	printf("\nThe sample demonstrates how to train Random Trees classifier\n"
		"using the provided dataset.\n"
		"\n"
		"We use the sample database letter-recognition.data\n"
		"from UCI Repository, here is the link:\n"
		"\n"
		"Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
		"UCI Repository of machine learning databases\n"
		"[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
		"Irvine, CA: University of California, Department of Information and Computer Science.\n"
		"\n"
		"The dataset consists of 20000 feature vectors along with the\n"
		"responses - capital latin letters A..Z.\n"
		"The first 16000 samples are used for training\n"
		"and the remaining 4000 - to test the classifier.\n"
		"======================================================\n");
	// The last usage line ends with a newline rather than a dangling
	// shell-style continuation backslash.
	printf("\nThis is letter recognition sample.\n"
		"The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
		"  [-save <output XML file for the classifier>] \\\n"
		"  [-load <XML file with the pre-trained classifier>]\n");
}

// Reads a comma-separated dataset from <filename>.
//
// Each line is expected to look like "<label char>,<v1>,<v2>,...,<vN>" with
// N == var_count numeric values. The character code of the first character
// becomes the response; the following values become one row of the data
// matrix.
//
// Parameters:
//   filename  - path of the text file to read
//   var_count - number of numeric values expected after the label on each line
//   data      - receives a newly created (rows x var_count) CV_32F matrix
//   responses - receives a newly created (rows x 1) CV_32F matrix
//
// Returns 1 on success. Returns 0 if the arguments are invalid, the file
// cannot be opened, or no complete row could be read; in that case *data and
// *responses are left untouched. On success the caller owns both matrices and
// releases them with cvReleaseMat().
//
// Reading stops at the first line that has no comma or fewer than var_count
// parsable values; rows read before that line are kept.
static int read_num_class_data( const char* filename, int var_count, CvMat** data, CvMat** responses )
{
	const int M = 1024;
	char buf[M+2];

	if( !filename || var_count <= 0 || !data || !responses )
		return 0;

	FILE* f = fopen( filename, "rt" );
	if( !f )
		return 0;

	// row[0] holds the label, row[1..var_count] the feature values.
	std::vector<float> row( var_count + 1 );
	CvMemStorage* storage = cvCreateMemStorage();
	CvSeq* seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );

	for(;;)
	{
		if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
			break;

		row[0] = buf[0];

		// Skip the label and the separator that follows it, but never step
		// past the string terminator on a degenerate one-character line.
		const char* ptr = buf[1] != '\0' ? buf + 2 : buf + 1;

		int i;
		for( i = 1; i <= var_count; i++ )
		{
			int n = 0;
			// Stop at the first value that fails to parse so the
			// incomplete-row check below can actually trigger.
			if( sscanf( ptr, "%f%n", &row[i], &n ) != 1 )
				break;
			ptr += n;
			// Consume the single separator character after the number,
			// unless the end of the line has been reached.
			if( *ptr != '\0' )
				ptr++;
		}
		if( i <= var_count )
			break;
		cvSeqPush( seq, &row[0] );
	}
	fclose(f);

	const int total = seq->total;
	if( total == 0 )
	{
		// cvCreateMat() rejects zero-sized matrices; report failure instead.
		cvReleaseMemStorage( &storage );
		return 0;
	}

	*data = cvCreateMat( total, var_count, CV_32F );
	*responses = cvCreateMat( total, 1, CV_32F );

	CvSeqReader reader;
	cvStartReadSeq( seq, &reader );

	for( int i = 0; i < total; i++ )
	{
		// Each sequence element is [label, v1, ..., vN]; sdata points at v1.
		const float* sdata = reinterpret_cast<const float*>(reader.ptr) + 1;
		float* ddata = data[0]->data.fl + var_count*i;
		float* dr = responses[0]->data.fl + i;

		for( int j = 0; j < var_count; j++ )
			ddata[j] = sdata[j];
		*dr = sdata[-1];
		CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
	}

	cvReleaseMemStorage( &storage );
	return 1;
}

static int build_rtrees_classifier( char* data_filename,	char* filename_to_save, char* filename_to_load )
{
	CvMat* data = 0;
	CvMat* responses = 0;
	CvMat* var_type = 0;
	CvMat* sample_idx = 0;

	int ok = read_num_class_data( data_filename, 16, &data, &responses );
	int nsamples_all = 0, ntrain_samples = 0;
	int i = 0;
	double train_hr = 0, test_hr = 0;
	CvRTrees forest;
	CvMat* var_importance = 0;

	if( !ok )
	{
		printf( "Could not read the database %s\n", data_filename );
		return -1;
	}

	printf( "The database %s is loaded.\n", data_filename );
	nsamples_all = data->rows;
	ntrain_samples = (int)(nsamples_all*0.8);

	// Create or load Random Trees classifier
	if( filename_to_load )
	{
		// load classifier from the specified file
		forest.load( filename_to_load );
		ntrain_samples = 0;
		if( forest.get_tree_count() == 0 )
		{
			printf( "Could not read the classifier %s\n", filename_to_load );
			return -1;
		}
		printf( "The classifier %s is loaded.\n", data_filename );
	}
	else
	{
		// create classifier by using <data> and <responses>
		printf( "Training the classifier ...\n");

		// 1. create type mask
		var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
		cvSet( var_type, cvScalarAll(CV_VAR_ORDERED) );
		cvSetReal1D( var_type, data->cols, CV_VAR_CATEGORICAL );

		// 2. create sample_idx
		sample_idx = cvCreateMat( 1, nsamples_all, CV_8UC1 );
		{
			CvMat mat;
			cvGetCols( sample_idx, &mat, 0, ntrain_samples );
			cvSet( &mat, cvRealScalar(1) );

			cvGetCols( sample_idx, &mat, ntrain_samples, nsamples_all );
			cvSetZero( &mat );
		}

		// 3. train classifier
		forest.train( data, CV_ROW_SAMPLE, responses, 0, sample_idx, var_type, 0,
			CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
		printf( "\n");
	}

	// compute prediction error on train and test data
	for( i = 0; i < nsamples_all; i++ )
	{
		double r;
		CvMat sample;
		cvGetRow( data, &sample, i );

		r = forest.predict( &sample );
		r = fabs((double)r - responses->data.fl[i]) <= FLT_EPSILON ? 1 : 0;

		if( i < ntrain_samples )
			train_hr += r;
		else
			test_hr += r;
	}

	test_hr /= (double)(nsamples_all-ntrain_samples);
	train_hr /= (double)ntrain_samples;
	printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
		train_hr*100., test_hr*100. );

	printf( "Number of trees: %d\n", forest.get_tree_count() );

	// Print variable importance
	var_importance = (CvMat*)forest.get_var_importance();
	if( var_importance )
	{
		double rt_imp_sum = cvSum( var_importance ).val[0];
		printf("var#\timportance (in %%):\n");
		for( i = 0; i < var_importance->cols; i++ )
			printf( "%-2d\t%-4.1f\n", i,
			100.f*var_importance->data.fl[i]/rt_imp_sum);
	}

	//Print some proximitites
	printf( "Proximities between some samples corresponding to the letter 'T':\n" );
	{
		CvMat sample1, sample2;
		const int pairs[][2] = {{0,103}, {0,106}, {106,103}, {-1,-1}};

		for( i = 0; pairs[i][0] >= 0; i++ )
		{
			cvGetRow( data, &sample1, pairs[i][0] );
			cvGetRow( data, &sample2, pairs[i][1] );
			printf( "proximity(%d,%d) = %.1f%%\n", pairs[i][0], pairs[i][1],
				forest.get_proximity( &sample1, &sample2 )*100. );
		}
	}

	// Save Random Trees classifier to file if needed
	if( filename_to_save )
		forest.save( filename_to_save );

	cvReleaseMat( &sample_idx );
	cvReleaseMat( &var_type );
	cvReleaseMat( &data );
	cvReleaseMat( &responses );

	return 0;
}

// Program entry point.
//
// Recognized options, each followed by a value:
//   -data <path>  dataset file (defaults to the hard-coded path below)
//   -save <path>  file to save the trained classifier to
//   -load <path>  file to load a pre-trained classifier from
//
// If an option is unknown, an option is missing its value, or the classifier
// run fails, the usage text is printed. The process exit code is always 0,
// as in the original sample.
int main( int argc, char *argv[] )
{
	char* filename_to_save = 0;
	char* filename_to_load = 0;
	char default_data_filename[] = "F:\\Program Files\\OpenCV2.4.0\\samples\\cpp\\letter-recognition.data";
	char* data_filename = default_data_filename;
	bool bad_args = false;

	for( int i = 1; i < argc && !bad_args; i++ )
	{
		// Pick the variable that the option's value should be stored in.
		char** target = 0;
		if( strcmp(argv[i],"-data") == 0 ) // flag "-data letter_recognition.xml"
			target = &data_filename;
		else if( strcmp(argv[i],"-save") == 0 ) // flag "-save filename.xml"
			target = &filename_to_save;
		else if( strcmp(argv[i],"-load") == 0 ) // flag "-load filename.xml"
			target = &filename_to_load;

		// Reject unknown options and options given without a value; the
		// latter would otherwise store argv[argc] (a null pointer).
		if( !target || i + 1 >= argc )
			bad_args = true;
		else
			*target = argv[++i];
	}

	if( bad_args ||
		build_rtrees_classifier( data_filename, filename_to_save, filename_to_load ) < 0 )
	{
		help();
	}
	return 0;
}

运行结果:



另参考:使用CvRTrees类对手写体数据作分类


  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值