OpenCV码源笔记——Decision Tree决策树

本文转自http://blog.csdn.net/sangni007/article/details/7490852

来自OpenCV2.3.1 sample/c/mushroom.cpp

 

1.首先读入agaricus-lepiota.data的训练样本。

   样本中第一项是e或p代表有毒或无毒的标志位;其他是特征,可以把每个样本看做一个特征向量;

   cvSeqPush( seq, el_ptr );读入序列seq中,每一项都存储一个样本即特征向量;

   之后,把特征向量与标志位分别读入CvMat* data与CvMat* reponses中

   还有一个CvMat* missing保留丢失位当前小于0位置;

 

2.训练样本

  1. dtree = new CvDTree;  
  2. dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,  
  3. CvDTreeParams( 8, // max depth  
  4. 10, // min sample count 样本数小于10时,停止分裂    
  5. 0, // regression accuracy: N/A here;回归树的限制精度  
  6. true// compute surrogate split, as we have missing data;;为真时,计算missing data和变量的重要性  
  7. 15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义  
  8. 10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds  
  9. true// use 1SE rule => smaller tree;If true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确  
  10. true// throw away the pruned tree branches  
  11. priors //错分类的代价我们判断的:有毒VS无毒 错误的代价比 the array of priors, the bigger p_weight, the more attention  
  12. // to the poisonous mushrooms   
  13. // (a mushroom will be judjed to be poisonous with bigger chance)  
  14. ));  
    dtree = new CvDTree;
    dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
    CvDTreeParams( 8, // max depth
    10, // min sample count 样本数小于10时,停止分裂 
    0, // regression accuracy: N/A here;回归树的限制精度
    true, // compute surrogate split, as we have missing data;;为真时,计算missing data和变量的重要性
    15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义
    10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds
    true, // use 1SE rule => smaller tree;If true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确
    true, // throw away the pruned tree branches
    priors //错分类的代价我们判断的:有毒VS无毒 错误的代价比 the array of priors, the bigger p_weight, the more attention
    // to the poisonous mushrooms
    // (a mushroom will be judjed to be poisonous with bigger chance)
    ));



 

3.

  1. double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;  
double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;


4.interactive_classification通过人工输入特征来判断。

 

  1. #include "opencv2/core/core_c.h"  
  2. #include "opencv2/ml/ml.hpp"   
  3. #include <stdio.h>   
  4.   
  5. void help()  
  6. {  
  7.     printf("\nThis program demonstrated the use of OpenCV's decision tree function for learning and predicting data\n"  
  8.         "Usage :\n"  
  9.         "./mushroom <path to agaricus-lepiota.data>\n"  
  10.         "\n"  
  11.         "The sample demonstrates how to build a decision tree for classifying mushrooms.\n"  
  12.         "It uses the sample base agaricus-lepiota.data from UCI Repository, here is the link:\n"  
  13.         "\n"  
  14.         "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"  
  15.         "UCI Repository of machine learning databases\n"  
  16.         "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"  
  17.         "Irvine, CA: University of California, Department of Information and Computer Science.\n"  
  18.         "\n"  
  19.         "// loads the mushroom database, which is a text file, containing\n"  
  20.         "// one training sample per row, all the input variables and the output variable are categorical,\n"  
  21.         "// the values are encoded by characters.\n\n");  
  22. }  
  23.   
  24. int mushroom_read_database( const char* filename, CvMat** data, CvMat** missing, CvMat** responses )  
  25. {  
  26.     const int M = 1024;  
  27.     FILE* f = fopen( filename, "rt" );  
  28.     CvMemStorage* storage;  
  29.     CvSeq* seq;  
  30.     char buf[M+2], *ptr;  
  31.     float* el_ptr;  
  32.     CvSeqReader reader;  
  33.     int i, j, var_count = 0;  
  34.   
  35.     if( !f )  
  36.         return 0;  
  37.   
  38.     // read the first line and determine the number of variables  
  39.     if( !fgets( buf, M, f ))  
  40.     {  
  41.         fclose(f);  
  42.         return 0;  
  43.     }  
  44.   
  45.     for( ptr = buf; *ptr != '\0'; ptr++ )  
  46.         var_count += *ptr == ',';//计算每个样本的数量,每个样本一个“,”,样本数量=var_count+1;  
  47.     assert( ptr - buf == (var_count+1)*2 );  
  48.   
  49.     // create temporary memory storage to store the whole database  
  50.     //把样本存入seq中,存储空间是storage;   
  51.     el_ptr = new float[var_count+1];  
  52.     storage = cvCreateMemStorage();  
  53.     seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );//  
  54.   
  55.     for(;;)  
  56.     {  
  57.         for( i = 0; i <= var_count; i++ )  
  58.         {  
  59.             int c = buf[i*2];  
  60.             el_ptr[i] = c == '?' ? -1.f : (float)c;  
  61.         }  
  62.         if( i != var_count+1 )  
  63.             break;  
  64.         cvSeqPush( seq, el_ptr );  
  65.         if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )  
  66.             break;  
  67.     }  
  68.     fclose(f);  
  69.   
  70.     // allocate the output matrices and copy the base there  
  71.     *data = cvCreateMat( seq->total, var_count, CV_32F );//行数:样本数量;列数:样本大小;  
  72.     *missing = cvCreateMat( seq->total, var_count, CV_8U );  
  73.     *responses = cvCreateMat( seq->total, 1, CV_32F );//样本标志;  
  74.   
  75.     cvStartReadSeq( seq, &reader );  
  76.   
  77.     for( i = 0; i < seq->total; i++ )  
  78.     {  
  79.         const float* sdata = (float*)reader.ptr + 1;  
  80.         float* ddata = data[0]->data.fl + var_count*i;  
  81.         float* dr = responses[0]->data.fl + i;  
  82.         uchar* dm = missing[0]->data.ptr + var_count*i;  
  83.   
  84.         for( j = 0; j < var_count; j++ )  
  85.         {  
  86.             ddata[j] = sdata[j];  
  87.             dm[j] = sdata[j] < 0;  
  88.         }  
  89.         *dr = sdata[-1];//样本的第一个位置是标志;  
  90.         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );  
  91.     }  
  92.   
  93.     cvReleaseMemStorage( &storage );  
  94.     delete el_ptr;  
  95.     return 1;  
  96. }  
  97.   
  98.   
  99. CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,  
  100.     const CvMat* responses, float p_weight )  
  101. {  
  102.     CvDTree* dtree;  
  103.     CvMat* var_type;  
  104.     int i, hr1 = 0, hr2 = 0, p_total = 0;  
  105.     float priors[] = { 1, p_weight };  
  106.   
  107.     var_type = cvCreateMat( data->cols + 1, 1, CV_8U );  
  108.     cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical  
  109.   
  110.     dtree = new CvDTree;  
  111.   
  112.     dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,  
  113.         CvDTreeParams( 8, // max depth  
  114.         10, // min sample count样本数小于10时,停止分裂  
  115.         0, // regression accuracy: N/A here;回归树的限制精度  
  116.         true// compute surrogate split, as we have missing data;为真时,计算missing data和可变的重要性正确度  
  117.         15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义  
  118.         10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation   
  119.         true// use 1SE rule => smaller treeIf true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确  
  120.         true// throw away the pruned tree branches  
  121.         priors // the array of priors, the bigger p_weight, the more attention  
  122.         // to the poisonous mushrooms   
  123.         // (a mushroom will be judjed to be poisonous with bigger chance)  
  124.         ));  
  125.   
  126.     // compute hit-rate on the training database, demonstrates predict usage.  
  127.       
  128.     for( i = 0; i < data->rows; i++ )  
  129.     {  
  130.         CvMat sample, mask;  
  131.         cvGetRow( data, &sample, i );  
  132.         cvGetRow( missing, &mask, i );  
  133.         double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;  
  134.         int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;//大于阈值FLT_EPSILON被判断为误检  
  135.         if( d )  
  136.         {  
  137.             if( r != 'p' )  
  138.                 hr1++;  
  139.             else  
  140.                 hr2++;  
  141.         }  
  142.         p_total += responses->data.fl[i] == 'p';  
  143.     }  
  144.   
  145.     printf( "Results on the training database:\n"  
  146.         "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"  
  147.         "\tFalse-alarms: %d (%g%%)\n", hr1, (double)hr1*100/p_total,  
  148.         hr2, (double)hr2*100/(data->rows - p_total) );  
  149.   
  150.     cvReleaseMat( &var_type );  
  151.   
  152.     return dtree;  
  153. }  
  154.   
  155.   
  156. static const char* var_desc[] =  
  157. {  
  158.     "cap shape (bell=b,conical=c,convex=x,flat=f)",  
  159.     "cap surface (fibrous=f,grooves=g,scaly=y,smooth=s)",  
  160.     "cap color (brown=n,buff=b,cinnamon=c,gray=g,green=r,\n\tpink=p,purple=u,red=e,white=w,yellow=y)",  
  161.     "bruises? (bruises=t,no=f)",  
  162.     "odor (almond=a,anise=l,creosote=c,fishy=y,foul=f,\n\tmusty=m,none=n,pungent=p,spicy=s)",  
  163.     "gill attachment (attached=a,descending=d,free=f,notched=n)",  
  164.     "gill spacing (close=c,crowded=w,distant=d)",  
  165.     "gill size (broad=b,narrow=n)",  
  166.     "gill color (black=k,brown=n,buff=b,chocolate=h,gray=g,\n\tgreen=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y)",  
  167.     "stalk shape (enlarging=e,tapering=t)",  
  168.     "stalk root (bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r)",  
  169.     "stalk surface above ring (ibrous=f,scaly=y,silky=k,smooth=s)",  
  170.     "stalk surface below ring (ibrous=f,scaly=y,silky=k,smooth=s)",  
  171.     "stalk color above ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",  
  172.     "stalk color below ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",  
  173.     "veil type (partial=p,universal=u)",  
  174.     "veil color (brown=n,orange=o,white=w,yellow=y)",  
  175.     "ring number (none=n,one=o,two=t)",  
  176.     "ring type (cobwebby=c,evanescent=e,flaring=f,large=l,\n\tnone=n,pendant=p,sheathing=s,zone=z)",  
  177.     "spore print color (black=k,brown=n,buff=b,chocolate=h,green=r,\n\torange=o,purple=u,white=w,yellow=y)",  
  178.     "population (abundant=a,clustered=c,numerous=n,\n\tscattered=s,several=v,solitary=y)",  
  179.     "habitat (grasses=g,leaves=l,meadows=m,paths=p\n\turban=u,waste=w,woods=d)",  
  180.     0  
  181. };  
  182.   
  183.   
  184. void print_variable_importance( CvDTree* dtree, const char** var_desc )  
  185. {  
  186.     const CvMat* var_importance = dtree->get_var_importance();  
  187.     int i;  
  188.     char input[1000];  
  189.   
  190.     if( !var_importance )  
  191.     {  
  192.         printf( "Error: Variable importance can not be retrieved\n" );  
  193.         return;  
  194.     }  
  195.   
  196.     printf( "Print variable importance information? (y/n) " );  
  197.     scanf( "%1s", input );  
  198.     if( input[0] != 'y' && input[0] != 'Y' )  
  199.         return;  
  200.   
  201.     for( i = 0; i < var_importance->cols*var_importance->rows; i++ )  
  202.     {  
  203.         double val = var_importance->data.db[i];  
  204.         if( var_desc )  
  205.         {  
  206.             char buf[100];  
  207.             int len = strchr( var_desc[i], '(' ) - var_desc[i] - 1;  
  208.             strncpy( buf, var_desc[i], len );  
  209.             buf[len] = '\0';  
  210.             printf( "%s", buf );  
  211.         }  
  212.         else  
  213.             printf( "var #%d", i );  
  214.         printf( ": %g%%\n", val*100. );  
  215.     }  
  216. }  
  217.   
  218. void interactive_classification( CvDTree* dtree, const char** var_desc )  
  219. {  
  220.     char input[1000];  
  221.     const CvDTreeNode* root;  
  222.     CvDTreeTrainData* data;  
  223.   
  224.     if( !dtree )  
  225.         return;  
  226.   
  227.     root = dtree->get_root();  
  228.     data = dtree->get_data();  
  229.   
  230.     for(;;)  
  231.     {  
  232.         const CvDTreeNode* node;  
  233.   
  234.         printf( "Start/Proceed with interactive mushroom classification (y/n): " );  
  235.         scanf( "%1s", input );  
  236.         if( input[0] != 'y' && input[0] != 'Y' )  
  237.             break;  
  238.         printf( "Enter 1-letter answers, '?' for missing/unknown value...\n" );   
  239.   
  240.         // custom version of predict   
  241.         //传统的预测方式;   
  242.         node = root;  
  243.         for(;;)  
  244.         {  
  245.             CvDTreeSplit* split = node->split;  
  246.             int dir = 0;  
  247.   
  248.             if( !node->left || node->Tn <= dtree->get_pruned_tree_idx() || !node->split )  
  249.                 break;  
  250.   
  251.             for( ; split != 0; )  
  252.             {  
  253.                 int vi = split->var_idx, j;  
  254.                 int count = data->cat_count->data.i[vi];  
  255.                 const int* map = data->cat_map->data.i + data->cat_ofs->data.i[vi];  
  256.   
  257.                 printf( "%s: ", var_desc[vi] );  
  258.                 scanf( "%1s", input );  
  259.   
  260.                 if( input[0] == '?' )  
  261.                 {  
  262.                     split = split->next;  
  263.                     continue;  
  264.                 }  
  265.   
  266.                 // convert the input character to the normalized value of the variable  
  267.                 for( j = 0; j < count; j++ )  
  268.                     if( map[j] == input[0] )  
  269.                         break;  
  270.                 if( j < count )  
  271.                 {  
  272.                     dir = (split->subset[j>>5] & (1 << (j&31))) ? -1 : 1;  
  273.                     if( split->inversed )  
  274.                         dir = -dir;  
  275.                     break;  
  276.                 }  
  277.                 else  
  278.                     printf( "Error: unrecognized value\n" );  
  279.             }  
  280.   
  281.             if( !dir )  
  282.             {  
  283.                 printf( "Impossible to classify the sample\n");  
  284.                 node = 0;  
  285.                 break;  
  286.             }  
  287.             node = dir < 0 ? node->left : node->right;  
  288.         }  
  289.   
  290.         if( node )  
  291.             printf( "Prediction result: the mushroom is %s\n",  
  292.             node->class_idx == 0 ? "EDIBLE" : "POISONOUS" );  
  293.         printf( "\n-----------------------------\n" );  
  294.     }  
  295. }  
  296.   
  297.   
  298. int main( int argc, char** argv )  
  299. {  
  300.     CvMat *data = 0, *missing = 0, *responses = 0;  
  301.     CvDTree* dtree;  
  302.     const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";  
  303.   
  304.     help();  
  305.   
  306.     if( !mushroom_read_database( base_path, &data, &missing, &responses ) )  
  307.     {  
  308.         printf( "\nUnable to load the training database\n\n");  
  309.         help();  
  310.         return -1;  
  311.     }  
  312.   
  313.     dtree = mushroom_create_dtree( data, missing, responses,  
  314.         10 // poisonous mushrooms will have 10x higher weight in the decision tree  
  315.         );  
  316.     cvReleaseMat( &data );  
  317.     cvReleaseMat( &missing );  
  318.     cvReleaseMat( &responses );  
  319.   
  320.     print_variable_importance( dtree, var_desc );  
  321.     interactive_classification( dtree, var_desc );  
  322.     delete dtree;  
  323.   
  324.     return 0;  
  325. }  

 

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值