在主函数中,最耀眼的一句话就是这个了:
- // Entry point of the whole cascade training. vecName / bgName are the
- // positive / negative sample sources (passed as _posFilename / _negFilename).
- classifier.train( cascadeDirName,
- vecName,
- bgName,
- numPos, numNeg,
- precalcValBufSize, precalcIdxBufSize,
- numStages,
- cascadeParams,
- *featureParams[cascadeParams.featureType],
- stageParams,
- baseFormatSave );
其实现如下:
- // Trains cascade stages [startNumStages, numStages) one at a time. Each stage
- // is trained on positives the current cascade still accepts and negatives it
- // wrongly accepts (see updateTrainingSet / fillPassedSamples).
- // Returns false if, after the loop, no stage is present in stageClassifiers.
- bool CvCascadeClassifier::
- train(
- const string _cascadeDirName,
- const string _posFilename,
- const string _negFilename,
- int _numPos, int _numNeg,
- int _precalcValBufSize, int _precalcIdxBufSize,
- int _numStages,
- const CvCascadeParams& _cascadeParams,
- const CvFeatureParams& _featureParams,
- const CvCascadeBoostParams& _stageParams,
- bool baseFormatSave )
- {
- // Start recording clock ticks for training time output
- const clock_t begin_time = clock();
- // Input validation is omitted in this excerpt; likewise below, only the
- // key statements are kept.
- ……
- // Report how many stages were already loaded; training resumes after them.
- int startNumStages = (int)stageClassifiers.size();
- if ( startNumStages > 1 )
- cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;
- else if ( startNumStages == 1)
- cout << endl << "Stage 0 is loaded" << endl;
- // Target leaf false-alarm rate: maxFalseAlarm^numStages divided by the
- // maximum depth of each weak tree (max_depth, default 1).
- double requiredLeafFARate
- = pow( (double) stageParams->maxFalseAlarm, (double) numStages )
- /(double)stageParams->max_depth;
- double tempLeafFARate;
- for( int i = startNumStages; i < numStages; i++ )
- {
- cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
- cout << "<BEGIN" << endl;
- // Could not collect a training set for this stage: stop.
- if ( !updateTrainingSet( requiredLeafFARate, tempLeafFARate ) )
- {
- cout << "Train dataset for temp stage can not be filled. "
- "Branch training terminated." << endl;
- break;
- }
- // The cascade already reaches the required leaf false-alarm rate: stop.
- if( tempLeafFARate <= requiredLeafFARate )
- {
- cout << "Required leaf false alarm rate achieved. "
- "Branch training terminated." << endl;
- break;
- }
- // Train the current stage.
- CvCascadeBoost* tempStage = new CvCascadeBoost;
- bool isStageTrained = tempStage->train(
- (CvFeatureEvaluator*)featureEvaluator,
- curNumSamples, _precalcValBufSize, _precalcIdxBufSize,
- *((CvCascadeBoostParams*)stageParams) );
- cout << "END>" << endl;
- // Stage training failed. tempStage was never handed to stageClassifiers,
- // so release it here (previously it leaked), then stop.
- if( !isStageTrained )
- {
- delete tempStage;
- break;
- }
- // Success: append the new stage to the cascade.
- stageClassifiers.push_back( tempStage );
- // save params
- ……
- // Output training time up till now
- ……
- }
- // Reached when the loop finishes or leaves through one of the breaks above.
- if(stageClassifiers.size() == 0)
- {
- cout << "Cascade classifier can't be trained."
- " Check the used training parameters." << endl;
- return false;
- }
- // Save the cascade classifier to dirName + CC_CASCADE_FILENAME.
- save( dirName + CC_CASCADE_FILENAME, baseFormatSave );
- return true;
- }
- // Refills the training set for the next stage: positives first (indices
- // [0, posCount)), then negatives right after them. Only samples for which
- // predict() returns 1 are kept. On success, stores the fraction of consumed
- // negatives that were accepted in acceptanceRatio and returns true. Returns
- // false (acceptanceRatio untouched) if either pool yields no sample.
- bool CvCascadeClassifier::updateTrainingSet(
- double minimumAcceptanceRatio,
- double & acceptanceRatio)
- {
- imgReader.restart();
- // Positives: ask for numPos; no acceptance-ratio cutoff (0).
- int64 usedPos = 0;
- const int gotPos = fillPassedSamples( 0, numPos, true, 0, usedPos );
- if( gotPos == 0 )
- return false;
- cout << "POS count : consumed " << gotPos << " : "
- << (int)usedPos << endl;
- // Scale the negative quota by gotPos/numPos so the training set keeps the
- // numPos:numNeg proportion. The product is formed in double because
- // numNeg * gotPos may overflow int.
- const int wantedNeg = cvRound( (double)numNeg * (double)gotPos / numPos );
- int64 usedNeg = 0;
- const int gotNeg = fillPassedSamples( gotPos, wantedNeg, false,
- minimumAcceptanceRatio, usedNeg );
- if( gotNeg == 0 )
- return false;
- curNumSamples = gotPos + gotNeg;
- // accepted negatives / consumed negatives, i.e. FP / (FP + TN).
- if( usedNeg == 0 )
- acceptanceRatio = 0;
- else
- acceptanceRatio = (double)gotNeg / (double)usedNeg;
- cout << "NEG count : acceptanceRatio "
- << gotNeg << " : " << acceptanceRatio << endl;
- return true;
- }
- // Fills sample slots [first, first + count) with images of one class (positive
- // or negative). Images are read one by one and written into the current slot
- // until predict() returns 1 for it; rejected images are overwritten. Every
- // image read increments consumed. Stops early when the reader runs out of
- // images, or when even one more acceptance would keep accepted/consumed at or
- // below minimumAcceptanceRatio. Returns the number of filled slots.
- int CvCascadeClassifier::fillPassedSamples(
- int first,
- int count,
- bool isPositive,
- double minimumAcceptanceRatio,
- int64 &consumed )
- {
- int accepted = 0;
- Mat img(cascadeParams.winSize, CV_8UC1);
- const int stop = first + count;
- for( int slot = first; slot < stop; ++slot )
- {
- bool slotFilled = false;
- while( !slotFilled )
- {
- const bool ratioTooLow = consumed != 0
- && ((double)accepted+1)/(double)(int64)consumed
- <= minimumAcceptanceRatio;
- if( ratioTooLow )
- return accepted;
- const bool haveImg = isPositive ? imgReader.getPos( img )
- : imgReader.getNeg( img );
- if( !haveImg )
- return accepted;
- ++consumed;
- // Store the image and its label (1 = positive, 0 = negative) in this slot.
- featureEvaluator->setImage( img, isPositive ? 1 : 0, slot );
- // Keep the slot only if the current cascade accepts the sample; for
- // negatives this keeps exactly the false positives.
- if( predict( slot ) == 1.0F )
- {
- ++accepted;
- printf("%s current samples: %d\r",
- isPositive ? "POS":"NEG", accepted);
- slotFilled = true;
- }
- }
- }
- return accepted;
- }
单级强分类器的实际预测过程如下:得分是该级所有弱分类器输出值之和;当 returnSum 为 false 时,再将该和与 threshold - CV_THRESHOLD_EPS 比较,低于它返回 0,否则返回 1。
- // Scores sample sampleIdx with this stage by adding up the value of the leaf
- // each weak tree reaches for it. With returnSum the raw sum is returned;
- // otherwise 0 if the sum is below threshold - CV_THRESHOLD_EPS, else 1.
- float CvCascadeBoost::predict( int sampleIdx, bool returnSum ) const
- {
- CV_Assert( weak );
- CvSeqReader treeReader;
- cvStartReadSeq( weak, &treeReader );
- cvSetSeqReaderPos( &treeReader, 0 );
- double stageSum = 0;
- const int treeCount = weak->total;
- for( int t = 0; t < treeCount; ++t )
- {
- CvBoostTree* weakTree;
- CV_READ_SEQ_ELEM( weakTree, treeReader );
- stageSum += static_cast<CvCascadeBoostTree*>( weakTree )->predict( sampleIdx )->value;
- }
- if( returnSum )
- return (float)stageSum;
- return stageSum < threshold - CV_THRESHOLD_EPS ? 0.0F : 1.0F;
- }
- // Trains one boosted stage on the first _numSamples samples of
- // _featureEvaluator. Weak trees are added until one of these happens: a tree
- // fails to train; no sample is left in subsample_mask; isErrDesired() returns
- // true (stage targets met); or params.weak_count trees exist.
- // Returns true if at least one weak tree was trained; otherwise clear()s.
- bool CvCascadeBoost::
- train(
- const CvFeatureEvaluator* _featureEvaluator,
- int _numSamples,
- int _precalcValBufSize, int _precalcIdxBufSize,
- const CvCascadeBoostParams& _params )
- {
- CV_Assert( !data );
- clear();
- data = new CvCascadeBoostTrainData(
- _featureEvaluator, _numSamples,
- _precalcValBufSize, _precalcIdxBufSize, _params );
- // The weak-tree sequence gets its own memory storage, reachable only via weak.
- weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), cvCreateMemStorage() );
- set_params( _params );
- const bool needsResponseCopy =
- _params.boost_type == LOGIT || _params.boost_type == GENTLE;
- if ( needsResponseCopy )
- data->do_responses_copy();
- // Initial sample weights (no tree yet).
- update_weights( 0 );
- cout << "+----+---------+---------+" << endl;
- cout << "| N | HR | FA |" << endl;
- cout << "+----+---------+---------+" << endl;
- for( ;; )
- {
- // Fit the next weak tree on the current weights.
- CvCascadeBoostTree* weakTree = new CvCascadeBoostTree;
- if( !weakTree->train( data, subsample_mask, this ) )
- {
- delete weakTree;
- break;
- }
- cvSeqPush( weak, &weakTree );
- // Reweight with the new tree, then trim weights (presumably this removes
- // low-weight samples from subsample_mask — confirm in trim_weights).
- update_weights( weakTree );
- trim_weights();
- if( cvCountNonZero(subsample_mask) == 0 )
- break;
- // Stop once the stage meets its targets or the tree budget is used up.
- if( isErrDesired() || weak->total >= params.weak_count )
- break;
- }
- if( weak->total <= 0 )
- {
- clear();
- return false;
- }
- data->is_classifier = true;
- data->free_train_data();
- return true;
- }
这样,一级就训练完了。