OpenCV码源笔记——(letter_recog.cpp)Random Forest part

Refer from http://blog.csdn.net/yangtrees/article/details/7498618

其中重要函数的参数解释:http://blog.csdn.net/sangni007/article/details/7488727

  1. read_num_class_data( const char* filename, int var_count,  
  2.                      CvMat** data, CvMat** responses )  
  3. {  
  4.     const int M = 1024;  
  5.     FILE* f = fopen( filename, "rt" );  
  6.     CvMemStorage* storage;  
  7.     CvSeq* seq;  
  8.     char buf[M+2];  
  9.     float* el_ptr;  
  10.     CvSeqReader reader;  
  11.     int i, j;  
  12.   
  13.     if( !f )  
  14.         return 0;  
  15.   
  16.     el_ptr = new float[var_count+1];  
  17.     storage = cvCreateMemStorage();  
  18.     seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );  
  19.   
  20.     for(;;)  
  21.     {  
  22.         char* ptr;  
  23.         if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )  
  24.             break;  
  25.         el_ptr[0] = buf[0];  
  26.         ptr = buf+2;  
  27.         for( i = 1; i <= var_count; i++ )  
  28.         {  
  29.             int n = 0;  
  30.             sscanf( ptr, "%f%n", el_ptr + i, &n );  
  31.             ptr += n + 1;  
  32.         }  
  33.         if( i <= var_count )  
  34.             break;  
  35.         cvSeqPush( seq, el_ptr );  
  36.     }  
  37.     fclose(f);  
  38.   
  39.     *data = cvCreateMat( seq->total, var_count, CV_32F );  
  40.     *responses = cvCreateMat( seq->total, 1, CV_32F );  
  41.   
  42.     cvStartReadSeq( seq, &reader );  
  43.   
  44.     for( i = 0; i < seq->total; i++ )  
  45.     {  
  46.         const float* sdata = (float*)reader.ptr + 1;  
  47.         float* ddata = data[0]->data.fl + var_count*i;  
  48.         float* dr = responses[0]->data.fl + i;  
  49.   
  50.         for( j = 0; j < var_count; j++ )  
  51.             ddata[j] = sdata[j];  
  52.         *dr = sdata[-1];  
  53.         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );  
  54.     }  
  55.   
  56.     cvReleaseMemStorage( &storage );  
  57.     delete el_ptr;  
  58.     return 1;  
  59. }  
  60.   
  61. static  
  62. int build_rtrees_classifier( char* data_filename,  
  63.     char* filename_to_save, char* filename_to_load )  
  64. {  
  65.     CvMat* data = 0;  
  66.     CvMat* responses = 0;  
  67.     CvMat* var_type = 0;  
  68.     CvMat* sample_idx = 0;  
  69.   
  70.     int ok = read_num_class_data( data_filename, 16, &data, &responses );  
  71.     int nsamples_all = 0, ntrain_samples = 0;  
  72.     int i = 0;  
  73.     double train_hr = 0, test_hr = 0;  
  74.     CvRTrees forest;  
  75.     CvMat* var_importance = 0;  
  76.   
  77.     if( !ok )  
  78.     {  
  79.         printf( "Could not read the database %s\n", data_filename );  
  80.         return -1;  
  81.     }  
  82.   
  83.     printf( "The database %s is loaded.\n", data_filename );  
  84.     nsamples_all = data->rows;  
  85.     ntrain_samples = (int)(nsamples_all*0.8);  
  86.   
  87.     // Create or load Random Trees classifier  
  88.     if( filename_to_load )  
  89.     {  
  90.         // load classifier from the specified file  
  91.         forest.load( filename_to_load );  
  92.         ntrain_samples = 0;  
  93.         if( forest.get_tree_count() == 0 )  
  94.         {  
  95.             printf( "Could not read the classifier %s\n", filename_to_load );  
  96.             return -1;  
  97.         }  
  98.         printf( "The classifier %s is loaded.\n", data_filename );  
  99.     }  
  100.     else  
  101.     {  
  102.         // create classifier by using <data> and <responses>  
  103.         printf( "Training the classifier ...\n");  
  104.   
  105.         // 1. create type mask  
  106.         var_type = cvCreateMat( data->cols + 1, 1, CV_8U );//response的类型;  
  107.         cvSet( var_type, cvScalarAll(CV_VAR_ORDERED) );  
  108.         cvSetReal1D( var_type, data->cols, CV_VAR_CATEGORICAL );  
  109.   
  110.         // 2. create sample_idx  
  111.         sample_idx = cvCreateMat( 1, nsamples_all, CV_8UC1 );  
  112.         {  
  113.             CvMat mat;  
  114.             cvGetCols( sample_idx, &mat, 0, ntrain_samples );  
  115.             cvSet( &mat, cvRealScalar(1) );  
  116.   
  117.             cvGetCols( sample_idx, &mat, ntrain_samples, nsamples_all );  
  118.             cvSetZero( &mat );  
  119.         }  
  120.   
  121.         // 3. train classifier  
  122.         forest.train( data, CV_ROW_SAMPLE, responses, 0, sample_idx, var_type, 0,  
  123.             CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));  
  124.         printf( "\n");  
  125.     }  
  126.   
  127.     // compute prediction error on train and test data  
  128.     for( i = 0; i < nsamples_all; i++ )  
  129.     {  
  130.         double r;  
  131.         CvMat sample;  
  132.         cvGetRow( data, &sample, i );  
  133.   
  134.         r = forest.predict( &sample );  
  135.         r = fabs((double)r - responses->data.fl[i]) <= FLT_EPSILON ? 1 : 0;  
  136.   
  137.         if( i < ntrain_samples )  
  138.             train_hr += r;  
  139.         else  
  140.             test_hr += r;  
  141.     }  
  142.   
  143.     test_hr /= (double)(nsamples_all-ntrain_samples);  
  144.     train_hr /= (double)ntrain_samples;  
  145.     printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",  
  146.             train_hr*100., test_hr*100. );  
  147.   
  148.     printf( "Number of trees: %d\n", forest.get_tree_count() );  
  149.   
  150.     // Print variable importance 打印自变量重要性;  
  151.     var_importance = (CvMat*)forest.get_var_importance();  
  152.     if( var_importance )  
  153.     {  
  154.         double rt_imp_sum = cvSum( var_importance ).val[0];  
  155.         printf("var#\timportance (in %%):\n");  
  156.         for( i = 0; i < var_importance->cols; i++ )  
  157.             printf( "%-2d\t%-4.1f\n", i,  
  158.             100.f*var_importance->data.fl[i]/rt_imp_sum);  
  159.     }  
  160.   
  161.     //Print some proximitites 打印相似度;  
  162.     printf( "Proximities between some samples corresponding to the letter 'T':\n" );  
  163.     {  
  164.         CvMat sample1, sample2;  
  165.         const int pairs[][2] = {{0,103}, {0,106}, {106,103}, {-1,-1}};  
  166.   
  167.         for( i = 0; pairs[i][0] >= 0; i++ )  
  168.         {  
  169.             cvGetRow( data, &sample1, pairs[i][0] );  
  170.             cvGetRow( data, &sample2, pairs[i][1] );  
  171.             printf( "proximity(%d,%d) = %.1f%%\n", pairs[i][0], pairs[i][1],  
  172.                 forest.get_proximity( &sample1, &sample2 )*100. );  
  173.         }  
  174.     }  
  175.   
  176.     // Save Random Trees classifier to file if needed  
  177.     //if( filename_to_save )  
  178.         forest.save( "forest.xml" );  
  179.   
  180.     cvReleaseMat( &sample_idx );  
  181.     cvReleaseMat( &var_type );  
  182.     cvReleaseMat( &data );  
  183.     cvReleaseMat( &responses );  
  184.   
  185.     return 0;  
  186. }  


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值