opencv中adaboost训练算法分析

0、概述  

opencv集成了经典adaboost算法,并结合haar特征实现了人脸检测功能。算法原理可参考人脸检测大牛Paul Viola 的文章《Rapid Object Detection using a Boosted Cascade of Simple Feature》。由于该算法堪称经典,并可推广应用于其他相关检测识别领域(如车牌检测、车辆检测识别),因此有必要从源码上学习其实现过程。人脸检测说是检测,实际上关键算法体现在训练(training)模块,openCV2.4.4 内含haartraining代码,本篇博文及后续相关博文都基于training算法为说明对象。

1 预备知识

1. 什么是haar特征?
haar特征是众多图像特征中的一种,计算方法为将特征窗口内的像素相加或相减。典型特征窗口如下图所示;
这里写图片描述
2. 一副图像上haar特征为什么有很多,具体有多少个?
由于haar特征窗口可有很多种类别,如上图中的(A、B、C、D不限于此四类)类,每一种类又可以变化特征窗口的尺度,在类别和尺度都确定的基础上,haar特征窗口还可以在样本图像上平移滑动。因此,一副图像上可生成众多haar特征。
具体个数请参考其他博文。haar特征及个数计算
3. 为什么要计算haar特征?
每一个haar特征可视为一个弱分类器(后面会解释),训练过程就是选择弱分类器(haar特征)的过程。
4. 什么是积分图,为什么用积分图?
积分图可加快haar特征的计算。
5. 什么是adaboost算法,如何实现?
集成学习算法包括两大类算法:bagging 算法和boosting 算法。
adaboost属于集成学习算法boosting的一种实现,在《Rapid Object Detection using a Boosted Cascade of Simple Features》作者论文中给出了adaboost算法流程图:
这里写图片描述

需要说明的是,该流程与opencv实现的方法有些出入,如openCV将负样本标记为-1,而文
献中标记为0,因此强分类器的阈值判断也不同。
openCV实现的adaboost,更接近于此片博文adaboost算法原理推导
这里写图片描述
6. 什么是决策树,为什么要用决策树?
决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。Entropy = 系统的凌乱程度,使用算法ID3, C4.5和C5.0生成树算法使用熵。这一度量是基于信息学理论中熵的概念。(盗用百度百科)
这里写图片描述
为啥要知道决策树? 因为opencV利用CART树,实现的弱分类器训练。

2 分类器训练

  为了表述的严谨性,本博文涉及的几个常用名词定义如下:
  (1)弱分类器:在本博文就是指CART,当CART只有一个分裂节点时,CART退化为Stump。CART的每个分裂节点都由1个haar特征。
  (2)强分类器:在本文下就是指Stage。

2.1 分类器训练概览

知道了haar特征,知道了adaboost,那么openCV到底是如何实现训练分类器的呢? 整体框架如下图所示。算法采用Cascade Tree结构,cascade Tree内部由多个stage构成,stage内部又由多个CART构成。opencv 实现分类器训练的过程就是建立cacade Tree的过程。
这里写图片描述

分类器最基本的训练单元为CART,然后是生成stage,最后建立cascade,以下详细展开分析。
源码调用关系图如下:
这里写图片描述

2.2 CART创建

  CART创建的关键是搜索并获得有效的haar特征。所涉及的函数及调用关系如下:
  

CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData, //训练样本
                      int flags,   //行、列
                      CvMat* trainClasses, //类别标识,正样本+1,负样本-1
                      CvMat* /*typeMask*/,
                      CvMat* missedMeasurementsMask,
                      CvMat* compIdx,
                      CvMat* sampleIdx,
                      CvMat* weights,  // 样本权重
                      CvClassifierTrainParams* trainParams //训练参数
                         //stumperror分裂规则:"misclass", "gini", "entropy"
                       )

cvCreateMTStumpClassifier 函数是创建CART的主函数,此过程是训练最耗时的一部分。此函数在内部调用findStumpThreshold_16s[stumperror],搜索弱分类器的阈值。

CvFindThresholdFunc findStumpThreshold_16s[4] = {
        icvFindStumpThreshold_misc_16s,
        icvFindStumpThreshold_gini_16s,
        icvFindStumpThreshold_entropy_16s,
        icvFindStumpThreshold_sq_16s
    };

以上为函数指针,每个函数指针都是以宏实现。以熵衰减规则分裂节点函数(icvFindStumpThreshold_gini_16s)为例,定义如下

/* entropy error
 * err = - wpos * log(wpos / (wpos + wneg)) - wneg * log(wneg / (wpos + wneg))
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( suffix, type )                             \
    ICV_DEF_FIND_STUMP_THRESHOLD( entropy_##suffix, type,                                \
        wposl = 0.5F * ( wl + wyl ); //左分支点中正样本权重和                                                    \
        wposr = 0.5F * ( wr + wyr ); //右分支点中正样本权重和                                                    \
        curleft = 0.5F * ( 1.0F + curleft ); //左分支中正样本权重和/左分支正负样本权重总和(左分支中正样本所占的比例,超过0.5则分类为正)                                           \
        curright = 0.5F * ( 1.0F + curright ); // 右分支正样本权重和/右分支正负样本权重总和(右分支中正样本所占比例,超过0.5则分类为负)                                         \
        curlerror = currerror = 0.0F;                                                    \
        if( curleft > CV_ENTROPY_THRESHOLD )  // 左分支熵                                           \
            curlerror -= wposl * logf( curleft );                                        \
        if( curleft < 1.0F - CV_ENTROPY_THRESHOLD )                                     \
            curlerror -= (wl - wposl) * logf( 1.0F - curleft );                          \
                                                                                         \
        if( curright > CV_ENTROPY_THRESHOLD )   //右分支熵                                          \
            currerror -= wposr * logf( curright );                                       \
        if( curright < 1.0F - CV_ENTROPY_THRESHOLD )                                     \
            currerror -= (wr - wposr) * logf( 1.0F - curright );                         \
    )

寻找弱分类器阈值函数

#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                              \
        uchar* data, size_t datastep,                                                    \
        uchar* wdata, size_t wstep,                                                      \
        uchar* ydata, size_t ystep,                                                      \
        uchar* idxdata, size_t idxstep, int num,                                         \
        float* lerror,                                                                   \
        float* rerror,                                                                   \
        float* threshold, float* left, float* right,                                     \
        float* sumw, float* sumwy, float* sumwyy )                                       \
{                                                                                        \
    int found = 0;                                                                       \
    float wyl  = 0.0F;  //左分支,各样本权重乘以类别y并求和                                                                 \
    float wl   = 0.0F;   // 左分支各样本权重求和                                                               \
    float wyyl = 0.0F;   // 左分支,各样本类别的平方乘以权重后求和                                                                \
    float wyr  = 0.0F;   //右分支,各样本权重乘以类别y并求和                                                                \
    float wr   = 0.0F;   // 右分支各样本权重求和                                                                  \
                                                                                         \
    float curleft  = 0.0F;   //左分支,正样本权重和/总的权重和                                                            \
    float curright = 0.0F;   //右分支,正样本权重和/总的权重和                                                            \
    float* prevval = NULL;                                                               \
    float* curval  = NULL;                                                               \
    float curlerror = 0.0F;    //左分支,的熵                                                          \
    float currerror = 0.0F;    //右分支,的熵                                                           \
    float wposl;   // 分配到左分支的正样本权重和                                                                     \
    float wposr;    //分配到右分支的正样本权重和                                                                     \
                                                                                         \
    int i = 0;                                                                           \
    int idx = 0;                                                                         \
                                                                                         \
    wposl = wposr = 0.0F;                                                                \
    if( *sumw == FLT_MAX )                                                               \
    {                                                                                    \
        /* calculate sums */                                                             \
        float *y = NULL;                                                                 \
        float *w = NULL;                                                                 \
        float wy = 0.0F;                                                                 \
                                                                                         \
        *sumw   = 0.0F;                                                                  \
        *sumwy  = 0.0F;                                                                  \
        *sumwyy = 0.0F;                                                                  \
        for( i = 0; i < num; i++ )                                                       \
        {                                                                                \
            idx = (int) ( *((type*) (idxdata + i*idxstep)) );                            \
            w = (float*) (wdata + idx * wstep);                                          \
            *sumw += *w;      //权重和                                                            \
            y = (float*) (ydata + idx * ystep);                                          \
            wy = (*w) * (*y);                                                            \
            *sumwy += wy;   //类别权重和                                                             \
            *sumwyy += wy * (*y);  //当y=+1或-1时,此值同 sumw                                                      \
        }                                                                                \
    }                                                                                    \
                                                                                         \
    for( i = 0; i < num; i++ )                                                           \
    {                                                                                    \
        idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                \
        curval = (float*) (data + idx * datastep);                                       \
         /* for debug purpose */                                                         \
        if( i > 0 ) assert( (*prevval) <= (*curval) );                                   \
                                                                                         \
        wyr  = *sumwy - wyl;                                                             \
        wr   = *sumw  - wl;                                                              \
                                                                                         \
        if( wl > 0.0 ) curleft = wyl / wl;                                               \
        else curleft = 0.0F;                                                             \
                                                                                         \
        if( wr > 0.0 ) curright = wyr / wr;                                              \
        else curright = 0.0F;                                                            \
                                                                                         \
        error                                                                            \
                                                                                         \    
        if( curlerror + currerror < (*lerror) + (*rerror) )                              \         查找到使熵最小的阈值点   
        {                                                                                \
            (*lerror) = curlerror;  //左分支熵                                                     \
            (*rerror) = currerror;  //右分支熵                                                     \
            *threshold = *curval;    //阈值                                                    \
            if( i > 0 ) {                                                                \
                *threshold = 0.5F * (*threshold + *prevval);                             \
            }                                                                            \
            *left  = curleft;  // 左分支中,正样本(权重)占左分支总权重的比例                                                          \
            *right = curright; //右分支,正样本(权重)占右分支总权重的比例                                                          \
            found = 1;                                                                   \
        }                                                                                \
                                                                                         \
        do           //计算左右权重和                                                                    \
        {                                                                                \
            wl  += *((float*) (wdata + idx * wstep));                                    \
            wyl += (*((float*) (wdata + idx * wstep)))                                   \
                * (*((float*) (ydata + idx * ystep)));                                   \
            wyyl += *((float*) (wdata + idx * wstep))                                    \
                * (*((float*) (ydata + idx * ystep)))                                    \
                * (*((float*) (ydata + idx * ystep)));                                   \
        }                                                                                \
        while( (++i) < num &&                                                            \
            ( *((float*) (data + (idx =                                                  \
                (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                 \
                == *curval ) );                                                          \
        --i;                                                                             \
        prevval = curval;                                                                \
    } /* for each value */                                                               \
                                                                                         \
    return found;                                                                        \
}

2.3 Stage 生成

某个CART弱分类器创建完成,然后根据adaboost权重更新规则,计算此弱分类器的 αi ,在DAB(Discrete AdaBoost)模式下,左右分支节点输出便是预测值与 αi 的乘积。
  Stage内弱分类器的个数,由maxfalsealarm确定。如果一直达不到小于maxfalsealarm的要求,则需要训练更多的弱分类器放入Stage中。
  Stage的 阈值由minhitrate确定。多个弱分类器的联合(即强分类器)对正样本的预测概率不能小于minhitrate。

2.4 几个问题说明

1,为什么训练过程中经常出现卡死,无法进入下一个stage的情况?
假设参数输入时候的负样本个数为1000; 则每一级stage都需要从待选的所有负样本中获得1000个被上一级错分为正样本的负样本(体现为虚警率)。由于越往后虚警率越低,同样是获取1000个错分样本,后面需要遍历的负样本范围越来越大。当最终无法得到错分样本时,程序便进入死循环中。

static
int icvGetHaarTrainingData( CvHaarTrainingData* data, int first, int count,
                            CvIntHaarClassifier* cascade,
                            CvGetHaarTrainingDataCallback callback, void* userdata,
                            int* consumed, double* acceptance_ratio )
{
    int i = 0;
    ccounter_t getcount = 0;
    ccounter_t thread_getcount = 0;
    ccounter_t consumed_count; 
    ccounter_t thread_consumed_count;

    /* private variables */
    CvMat img;
    CvMat sum;
    CvMat tilted;
    CvMat sqsum;

    sum_type* sumdata;
    sum_type* tilteddata;
    float*    normfactor;

    /* end private variables */

    assert( data != NULL );
    assert( first + count <= data->maxnum );
    assert( cascade != NULL );
    assert( callback != NULL );

    // if( !cvbgdata ) return 0; this check needs to be done in the callback for BG

    CCOUNTER_SET_ZERO(getcount);
    CCOUNTER_SET_ZERO(thread_getcount);
    CCOUNTER_SET_ZERO(consumed_count);
    CCOUNTER_SET_ZERO(thread_consumed_count);

    #ifdef CV_OPENMP
    #pragma omp parallel private(img, sum, tilted, sqsum, sumdata, tilteddata, \
                                 normfactor, thread_consumed_count, thread_getcount)
    #endif /* CV_OPENMP */
    {
        sumdata    = NULL;
        tilteddata = NULL;
        normfactor = NULL;

        CCOUNTER_SET_ZERO(thread_getcount);
        CCOUNTER_SET_ZERO(thread_consumed_count);
        int ok = 1;

        img = cvMat( data->winsize.height, data->winsize.width, CV_8UC1,
            cvAlloc( sizeof( uchar ) * data->winsize.height * data->winsize.width ) );
        sum = cvMat( data->winsize.height + 1, data->winsize.width + 1,
                     CV_SUM_MAT_TYPE, NULL );
        tilted = cvMat( data->winsize.height + 1, data->winsize.width + 1,
                        CV_SUM_MAT_TYPE, NULL );
        sqsum = cvMat( data->winsize.height + 1, data->winsize.width + 1, CV_SQSUM_MAT_TYPE,
                       cvAlloc( sizeof( sqsum_type ) * (data->winsize.height + 1)
                                                     * (data->winsize.width + 1) ) );

        #ifdef CV_OPENMP
        #pragma omp for schedule(static, 1)
        #endif /* CV_OPENMP */
        for( i = first; (i < first + count); i++ )
        {
            if( !ok )
                continue;
            for( ; ; ) //当没有合适的负样本时,陷入死循环
            {
                ok = callback( &img, userdata );
                if( !ok )
                    break;

                CCOUNTER_INC(thread_consumed_count);

                sumdata = (sum_type*) (data->sum.data.ptr + i * data->sum.step);
                tilteddata = (sum_type*) (data->tilted.data.ptr + i * data->tilted.step);
                normfactor = data->normfactor.data.fl + i;
                sum.data.ptr = (uchar*) sumdata;
                tilted.data.ptr = (uchar*) tilteddata;
                icvGetAuxImages( &img, &sum, &tilted, &sqsum, normfactor );            
                if( cascade->eval( cascade, sumdata, tilteddata, *normfactor ) != 0.0F )//读取正样本时,有可能小于正样本个数;
                                                                                        //读取负样本时,反应的是错分为正样本的个数
                {
                    CCOUNTER_INC(thread_getcount);
                    break;
                }
            }

#ifdef CV_VERBOSE
            if( (i - first) % 500 == 0 )
            {
                fprintf( stderr, "%3d%%\r", (int) ( 100.0 * (i - first) / count ) );
                fflush( stderr );
            }
#endif /* CV_VERBOSE */
        }

        cvFree( &(img.data.ptr) );
        cvFree( &(sqsum.data.ptr) );

        #ifdef CV_OPENMP
        #pragma omp critical (c_consumed_count)
        #endif /* CV_OPENMP */
        {
            /* consumed_count += thread_consumed_count; */
            CCOUNTER_ADD(getcount, thread_getcount);
            CCOUNTER_ADD(consumed_count, thread_consumed_count);
        }
    } /* omp parallel */

    if( consumed != NULL )
    {
        *consumed = (int)consumed_count;
    }

    if( acceptance_ratio != NULL )
    {
        /* *acceptance_ratio = ((double) count) / consumed_count; */
        *acceptance_ratio = CCOUNTER_DIV(count, consumed_count); // 计算虚警率
    }

    return static_cast<int>(getcount);
}

2,负样本是怎么截取的?
将输入图像等比例缩小到最小尺寸(不小于样本尺寸),然后再逐渐放大至原始尺寸。期间,用样本同样大小(width,height)的窗口滑动(步长为width/2,height/2 ),截取负样本。

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值