train_cascade 源码阅读之级联训练

在主函数中,最耀眼的一句话就是这个了:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. classifier.train( cascadeDirName,  
  2.                   vecName,  
  3.                   bgName,  
  4.                   numPos, numNeg,  
  5.                   precalcValBufSize, precalcIdxBufSize,  
  6.                   numStages,  
  7.                   cascadeParams,  
  8.                   *featureParams[cascadeParams.featureType],  
  9.         stageParams,  
  10.         baseFormatSave );  

其实现如下:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. bool CvCascadeClassifier::  
  2. train(  
  3.         const string _cascadeDirName,  
  4.         const string _posFilename,  
  5.         const string _negFilename,  
  6.         int _numPos, int _numNeg,  
  7.         int _precalcValBufSize, int _precalcIdxBufSize,  
  8.         int _numStages,  
  9.         const CvCascadeParams& _cascadeParams,  
  10.         const CvFeatureParams& _featureParams,  
  11.         const CvCascadeBoostParams& _stageParams,  
  12.         bool baseFormatSave )  
  13. {  
  14.     // Start recording clock ticks for training time output  
  15.     const clock_t begin_time = clock();  
  16.   
  17.     //确认数据是否有效,略过,下面代码同理,只保留关键语句  
  18.     ……  
  19.   
  20.     //判断读入数据级数,并显示  
  21.     int startNumStages = (int)stageClassifiers.size();  
  22.     if ( startNumStages > 1 )  
  23.         cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;  
  24.     else if ( startNumStages == 1)  
  25.         cout << endl << "Stage 0 is loaded" << endl;  
  26.       
  27.     //计算要求的叶节点虚警率  
  28.     double requiredLeafFARate  
  29.             = pow( (double) stageParams->maxFalseAlarm, (double) numStages )   
  30.                    /(double)stageParams->max_depth;//子树最大深度,默认为1  
  31.     double tempLeafFARate;  
  32.   
  33.     forint i = startNumStages; i < numStages; i++ )  
  34.     {  
  35.         cout << endl << "===== TRAINING " << i << "-stage =====" << endl;  
  36.         cout << "<BEGIN" << endl;  
  37.   
  38.         //无法满足需要的数据,返回  
  39.         if ( !updateTrainingSet( requiredLeafFARate, tempLeafFARate ) )  
  40.         {  
  41.             cout << "Train dataset for temp stage can not be filled. "  
  42.                     "Branch training terminated." << endl;  
  43.             break;  
  44.         }  
  45.         //叶节点虚警率已经达到要求,返回  
  46.         if( tempLeafFARate <= requiredLeafFARate )  
  47.         {  
  48.             cout << "Required leaf false alarm rate achieved. "  
  49.                     "Branch training terminated." << endl;  
  50.             break;  
  51.         }  
  52.         //开始训练本级  
  53.         CvCascadeBoost* tempStage = new CvCascadeBoost;  
  54.         bool isStageTrained = tempStage->train(  
  55.                     (CvFeatureEvaluator*)featureEvaluator,  
  56.                     curNumSamples, _precalcValBufSize, _precalcIdxBufSize,  
  57.                     *((CvCascadeBoostParams*)stageParams) );  
  58.         cout << "END>" << endl;  
  59.         //本级训练失败,返回  
  60.         if(!isStageTrained)  
  61.             break;  
  62.         //成功,添加本级  
  63.         stageClassifiers.push_back( tempStage );  
  64.   
  65.         // save params  
  66.         ……  
  67.   
  68.         // Output training time up till now  
  69.         ……  
  70.     }  
  71.     //上面for循环,break出来的  
  72.     if(stageClassifiers.size() == 0)  
  73.     {  
  74.         cout << "Cascade classifier can't be trained."  
  75.                 " Check the used training parameters." << endl;  
  76.         return false;  
  77.     }  
  78.   
  79.     //保存级联分类器到xml格式中  
  80.     save( dirName + CC_CASCADE_FILENAME, baseFormatSave );  
  81.     return true;  
  82. }  
接着看updateTrainingSet,每一级操作前先更新样本数据

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. bool CvCascadeClassifier::updateTrainingSet(  
  2.         double  minimumAcceptanceRatio,  
  3.         double  & acceptanceRatio)  
  4. {  
  5.     int64 posConsumed = 0, negConsumed = 0;  
  6.     imgReader.restart();  
  7.     //获取正样本  
  8.     int posCount = fillPassedSamples( 0, numPos, true, 0, posConsumed );  
  9.     if( !posCount )  
  10.         return false;  
  11.     cout << "POS count : consumed   " << posCount << " : "  
  12.          << (int)posConsumed << endl;  
  13.     //计算需要的负样本 负样本总数乘以 获得的正样本与正样本总数之比,保持了选取训练样本的正负样本比例不变  
  14.     int proNumNeg = cvRound(  
  15.                 (((double)numNeg) * ((double)posCount) ) / numPos  
  16.                 );  
  17.     // apply only a fraction of negative samples.  
  18.     //double is required since overflow is possible  
  19.     //获取负样本  
  20.     int negCount = fillPassedSamples(  
  21.                 posCount,  
  22.                 proNumNeg,  
  23.                 false,  
  24.                 minimumAcceptanceRatio,  
  25.                 negConsumed );  
  26.     if ( !negCount )  
  27.         return false;  
  28.   
  29.     curNumSamples = posCount + negCount;  
  30.     //计算acceptanceRatio,也就是FP/(FP+TN)  
  31.     acceptanceRatio = negConsumed == 0 ?  
  32.                 0 : ( (double)negCount/(double)(int64)negConsumed );  
  33.     cout << "NEG count : acceptanceRatio    "  
  34.          << negCount << " : " << acceptanceRatio << endl;  
  35.     return true;  
  36. }  
上段代码中最重要的就是fillPassedSamples函数了。

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. int CvCascadeClassifier::fillPassedSamples(  
  2.         int     first,  
  3.         int     count,  
  4.         bool    isPositive,  
  5.         double  minimumAcceptanceRatio,  
  6.         int64   &consumed )  
  7. {  
  8.     int getcount = 0;  
  9.     Mat img(cascadeParams.winSize, CV_8UC1);  
  10.     forint i = first; i < first + count; i++ )  
  11.     {  
  12.         for( ; ; )  
  13.         {  
  14.             if( consumed != 0  
  15.                     && ((double)getcount+1)/(double)(int64)consumed  
  16.                     <= minimumAcceptanceRatio )  
  17.                 return getcount;  
  18.             //获取对应类别图片  
  19.             bool isGetImg = isPositive ? imgReader.getPos( img ) :  
  20.                                          imgReader.getNeg( img );  
  21.             if( !isGetImg )  
  22.                 return getcount;  
  23.             consumed++;  
  24.             //在数据矩阵中设置图像类别  
  25.             featureEvaluator->setImage( img, isPositive ? 1 : 0, i );  
  26.             //如果预测为正样本就跳出循环,在填充负样本的过程中,返回的也是误判为正样本的值。  
  27.             if( predict( i ) == 1.0F )  
  28.             {  
  29.                 getcount++;  
  30.                 printf("%s current samples: %d\r",  
  31.                        isPositive ? "POS":"NEG", getcount);  
  32.                 break;  
  33.             }  
  34.         }  
  35.     }  
  36.     return getcount;  
  37. }  

实际的预测过程,值是每个弱分类器预测值的和。

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. float CvCascadeBoost::predict( int sampleIdx, bool returnSum ) const  
  2. {  
  3.     CV_Assert( weak );  
  4.     double sum = 0;  
  5.     CvSeqReader reader;  
  6.     cvStartReadSeq( weak, &reader );  
  7.     cvSetSeqReaderPos( &reader, 0 );  
  8.     forint i = 0; i < weak->total; i++ )  
  9.     {  
  10.         CvBoostTree* wtree;  
  11.         CV_READ_SEQ_ELEM( wtree, reader );  
  12.         sum += ((CvCascadeBoostTree*)wtree)->predict(sampleIdx)->value;  
  13.     }  
  14.     if( !returnSum )  
  15.         sum = sum < threshold - CV_THRESHOLD_EPS ? 0.0 : 1.0;  
  16.     return (float)sum;  
  17. }  
到了训练部分:

[cpp]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. bool CvCascadeBoost::  
  2. train(  
  3.         const CvFeatureEvaluator* _featureEvaluator,  
  4.         int _numSamples,  
  5.         int _precalcValBufSize, int _precalcIdxBufSize,  
  6.         const CvCascadeBoostParams& _params )  
  7. {  
  8.     bool isTrained = false;  
  9.     CV_Assert( !data );  
  10.     clear();  
  11.     data = new CvCascadeBoostTrainData(  
  12.                 _featureEvaluator, _numSamples,  
  13.                 _precalcValBufSize, _precalcIdxBufSize, _params );  
  14.     CvMemStorage *storage = cvCreateMemStorage();  
  15.     weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );  
  16.     storage = 0;  
  17.   
  18.     set_params( _params );  
  19.     if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )  
  20.         data->do_responses_copy();  
  21.     //初始化权值  
  22.     update_weights( 0 );  
  23.   
  24.     cout << "+----+---------+---------+" << endl;  
  25.     cout << "|  N |    HR   |    FA   |" << endl;  
  26.     cout << "+----+---------+---------+" << endl;  
  27.   
  28.     do  
  29.     {  
  30.         //训练树  
  31.         CvCascadeBoostTree* tree = new CvCascadeBoostTree;  
  32.         if( !tree->train( data, subsample_mask, this ) )  
  33.         {  
  34.             delete tree;  
  35.             break;  
  36.         }  
  37.         cvSeqPush( weak, &tree );  
  38.         //更新权值  
  39.         update_weights( tree );  
  40.         trim_weights();  
  41.         if( cvCountNonZero(subsample_mask) == 0 )  
  42.             break;  
  43.     }  
  44.     while( !isErrDesired() && (weak->total < params.weak_count) );  
  45.     //循环终止条件,虚警率达到要求或者达到最大弱分类器数目  
  46.   
  47.     if(weak->total > 0)  
  48.     {  
  49.         data->is_classifier = true;  
  50.         data->free_train_data();  
  51.         isTrained = true;  
  52.     }  
  53.     else  
  54.         clear();  
  55.   
  56.     return isTrained;  
  57. }  

这样,一级就训练完了。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值