train_cascade 源码阅读之级联训练

本文深入探讨了 OpenCV 中 train_cascade 的源码，特别是级联分类器的决策树训练过程。核心在于理解每个弱分类器的预测值如何通过累加形式影响最终的分类决策。

在主函数中,最耀眼的一句话就是这个了:

    classifier.train( cascadeDirName,
                      vecName,
                      bgName,
                      numPos, numNeg,
                      precalcValBufSize, precalcIdxBufSize,
                      numStages,
                      cascadeParams,
                      *featureParams[cascadeParams.featureType],
            stageParams,
            baseFormatSave );

其实现如下:

// Trains the full cascade: resumes from any previously saved stages, then
// trains the remaining stages one at a time until the required leaf false
// alarm rate is reached, the sample pool runs dry, a stage fails to train,
// or _numStages stages exist. Returns false only if no stage at all could
// be trained; otherwise the cascade is saved to XML and true is returned.
bool CvCascadeClassifier::
train(
        const string _cascadeDirName,
        const string _posFilename,
        const string _negFilename,
        int _numPos, int _numNeg,
        int _precalcValBufSize, int _precalcIdxBufSize,
        int _numStages,
        const CvCascadeParams& _cascadeParams,
        const CvFeatureParams& _featureParams,
        const CvCascadeBoostParams& _stageParams,
        bool baseFormatSave )
{
    // Start recording clock ticks for training time output
    const clock_t begin_time = clock();

    // Validate the inputs (elided by the article; throughout this listing
    // only the key statements are kept)
    ……

    // Report how many stages were loaded from a previous (interrupted) run
    int startNumStages = (int)stageClassifiers.size();
    if ( startNumStages > 1 )
        cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;
    else if ( startNumStages == 1)
        cout << endl << "Stage 0 is loaded" << endl;
    
    // Required leaf-node false alarm rate:
    // maxFalseAlarm^numStages divided by the tree depth
    double requiredLeafFARate
            = pow( (double) stageParams->maxFalseAlarm, (double) numStages ) 
                   /(double)stageParams->max_depth;// max subtree depth, 1 by default
    double tempLeafFARate;

    for( int i = startNumStages; i < numStages; i++ )
    {
        cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
        cout << "<BEGIN" << endl;

        // Not enough samples survive the current cascade — stop training
        if ( !updateTrainingSet( requiredLeafFARate, tempLeafFARate ) )
        {
            cout << "Train dataset for temp stage can not be filled. "
                    "Branch training terminated." << endl;
            break;
        }
        // Leaf false alarm rate already meets the requirement — stop training
        if( tempLeafFARate <= requiredLeafFARate )
        {
            cout << "Required leaf false alarm rate achieved. "
                    "Branch training terminated." << endl;
            break;
        }
        // Train the current stage (a boosted classifier)
        CvCascadeBoost* tempStage = new CvCascadeBoost;
        bool isStageTrained = tempStage->train(
                    (CvFeatureEvaluator*)featureEvaluator,
                    curNumSamples, _precalcValBufSize, _precalcIdxBufSize,
                    *((CvCascadeBoostParams*)stageParams) );
        cout << "END>" << endl;
        // Current stage failed to train — stop
        if(!isStageTrained)
            break;
        // Success: append the trained stage to the cascade
        stageClassifiers.push_back( tempStage );

        // save params
        ……

        // Output training time up till now
        ……
    }
    // Reached via break from the loop above (or after all stages):
    // no stage at all means training failed outright
    if(stageClassifiers.size() == 0)
    {
        cout << "Cascade classifier can't be trained."
                " Check the used training parameters." << endl;
        return false;
    }

    // Save the cascade classifier in XML format
    save( dirName + CC_CASCADE_FILENAME, baseFormatSave );
    return true;
}
接着看updateTrainingSet,每一级操作前先更新样本数据

bool CvCascadeClassifier::updateTrainingSet(
        double  minimumAcceptanceRatio,
        double  & acceptanceRatio)
{
    int64 posConsumed = 0, negConsumed = 0;
    imgReader.restart();
    //获取正样本
    int posCount = fillPassedSamples( 0, numPos, true, 0, posConsumed );
    if( !posCount )
        return false;
    cout << "POS count : consumed   " << posCount << " : "
         << (int)posConsumed << endl;
    //计算需要的负样本 负样本总数乘以 获得的正样本与正样本总数之比,保持了选取训练样本的正负样本比例不变
    int proNumNeg = cvRound(
                (((double)numNeg) * ((double)posCount) ) / numPos
                );
    // apply only a fraction of negative samples.
    //double is required since overflow is possible
    //获取负样本
    int negCount = fillPassedSamples(
                posCount,
                proNumNeg,
                false,
                minimumAcceptanceRatio,
                negConsumed );
    if ( !negCount )
        return false;

    curNumSamples = posCount + negCount;
    //计算acceptanceRatio,也就是FP/(FP+TN)
    acceptanceRatio = negConsumed == 0 ?
                0 : ( (double)negCount/(double)(int64)negConsumed );
    cout << "NEG count : acceptanceRatio    "
         << negCount << " : " << acceptanceRatio << endl;
    return true;
}
上段代码中最重要的就是fillPassedSamples函数了。

int CvCascadeClassifier::fillPassedSamples(
        int     first,
        int     count,
        bool    isPositive,
        double  minimumAcceptanceRatio,
        int64   &consumed )
{
    int getcount = 0;
    Mat img(cascadeParams.winSize, CV_8UC1);
    for( int i = first; i < first + count; i++ )
    {
        for( ; ; )
        {
            if( consumed != 0
                    && ((double)getcount+1)/(double)(int64)consumed
                    <= minimumAcceptanceRatio )
                return getcount;
            //获取对应类别图片
            bool isGetImg = isPositive ? imgReader.getPos( img ) :
                                         imgReader.getNeg( img );
            if( !isGetImg )
                return getcount;
            consumed++;
            //在数据矩阵中设置图像类别
            featureEvaluator->setImage( img, isPositive ? 1 : 0, i );
            //如果预测为正样本就跳出循环,在填充负样本的过程中,返回的也是误判为正样本的值。
            if( predict( i ) == 1.0F )
            {
                getcount++;
                printf("%s current samples: %d\r",
                       isPositive ? "POS":"NEG", getcount);
                break;
            }
        }
    }
    return getcount;
}

实际的预测过程,值是每个弱分类器预测值的和。

float CvCascadeBoost::predict( int sampleIdx, bool returnSum ) const
{
    CV_Assert( weak );
    double sum = 0;
    CvSeqReader reader;
    cvStartReadSeq( weak, &reader );
    cvSetSeqReaderPos( &reader, 0 );
    for( int i = 0; i < weak->total; i++ )
    {
        CvBoostTree* wtree;
        CV_READ_SEQ_ELEM( wtree, reader );
        sum += ((CvCascadeBoostTree*)wtree)->predict(sampleIdx)->value;
    }
    if( !returnSum )
        sum = sum < threshold - CV_THRESHOLD_EPS ? 0.0 : 1.0;
    return (float)sum;
}
到了训练部分:

bool CvCascadeBoost::
train(
        const CvFeatureEvaluator* _featureEvaluator,
        int _numSamples,
        int _precalcValBufSize, int _precalcIdxBufSize,
        const CvCascadeBoostParams& _params )
{
    bool isTrained = false;
    CV_Assert( !data );
    clear();
    data = new CvCascadeBoostTrainData(
                _featureEvaluator, _numSamples,
                _precalcValBufSize, _precalcIdxBufSize, _params );
    CvMemStorage *storage = cvCreateMemStorage();
    weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
    storage = 0;

    set_params( _params );
    if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
        data->do_responses_copy();
    //初始化权值
    update_weights( 0 );

    cout << "+----+---------+---------+" << endl;
    cout << "|  N |    HR   |    FA   |" << endl;
    cout << "+----+---------+---------+" << endl;

    do
    {
        //训练树
        CvCascadeBoostTree* tree = new CvCascadeBoostTree;
        if( !tree->train( data, subsample_mask, this ) )
        {
            delete tree;
            break;
        }
        cvSeqPush( weak, &tree );
        //更新权值
        update_weights( tree );
        trim_weights();
        if( cvCountNonZero(subsample_mask) == 0 )
            break;
    }
    while( !isErrDesired() && (weak->total < params.weak_count) );
    //循环终止条件,虚警率达到要求或者达到最大弱分类器数目

    if(weak->total > 0)
    {
        data->is_classifier = true;
        data->free_train_data();
        isTrained = true;
    }
    else
        clear();

    return isTrained;
}

这样,一级就训练完了。



（全文完）