adaboost训练——弱分类器训练的opencv源码详解 (1)

转自:http://blog.csdn.net/lanxuecc/article/details/52688605

opencv中adaboost训练弱分类器的主体代码是函数cvCreateCARTClassifier,这个函数通过大致逻辑是:

1、通过调用训练结点函数cvCreateMTStumpClassifier来创建根结点 
2、在要求弱分类器特征不只一个的情况下,通过分裂结点来增加新的特征形成CART树的弱分类器。

源码及注释如下

CV_BOOST_IMPL
CvClassifier* cvCreateCARTClassifier( CvMat* trainData,     //预计算的训练样本每个特征的值矩阵
                                      int flags,            //1表示样本按行排列,0表示样本按行排列
                                      CvMat* trainClasses,  //训练样本类别向量,如果是正样本标识为1,负样本标识为-1
                                      CvMat* typeMask,      //为了便于回调函数而统一格式的变量
                                      CvMat* missedMeasurementsMask,  //同上
                                      CvMat* compIdx,           //特征序列向量
                                      CvMat* sampleIdx,         //样本序列向量
                                      CvMat* weights,           //样本权值向量
                                      CvClassifierTrainParams* trainParams ) //传入一些弱分类器所需的参数比如需要几个特征,和一些需用的分类函数指针 
{
    CvCARTClassifier* cart = NULL;//CART树状弱分类器
    size_t datasize = 0;
    int count = 0;                // CART中的节点数目
    int i = 0;
    int j = 0;

    CvCARTNode* intnode = NULL;  // CART节点  
    CvCARTNode* list = NULL;     // 候选节点链表
    int listcount = 0;           // 候选节点个数
    CvMat* lidx = NULL;          // 左子节点样本序列
    CvMat* ridx = NULL;          // 右子节点样本序列 

    float maxerrdrop = 0.0F;
    int idx = 0;

    //定义节点分裂函数指针  这个函数指针指向的是函数icvSplitIndicesCallback
    void (*splitIdxCallback)( int compidx, float threshold,
                              CvMat* idx, CvMat** left, CvMat** right,
                              void* userdata );
    void* userdata;

    //设置非叶子节点个数  
    count = ((CvCARTTrainParams*) trainParams)->count;  /*弱分类器的特征个数,一般都只有一个*/

    assert( count > 0 );

    /*分配一个弱分类器的内存空间*/
    datasize = sizeof( *cart ) + (sizeof( float ) + 3 * sizeof( int )) * count + 
        sizeof( float ) * (count + 1);

    cart = (CvCARTClassifier*) cvAlloc( datasize );
    memset( cart, 0, datasize );

    /*初始化弱分类器*/
    cart->count = count;

    cart->eval = cvEvalCARTClassifier;  /*弱分类器使用函数*/
    cart->save = NULL;
    cart->release = cvReleaseCARTClassifier;  /*弱分类器内存释放函数 */

    cart->compidx = (int*) (cart + 1);                     //非叶子节点的最优Haar特征序号
    cart->threshold = (float*) (cart->compidx + count);    //非叶子节点的最优Haar特征阈值 
    cart->left  = (int*) (cart->threshold + count);       //左子节点序号,包含叶子节点序号
    cart->right = (int*) (cart->left + count);            //右子节点序号,包含叶子节点序号
    cart->val = (float*) (cart->right + count);           //叶子节点输出置信度数组  

    datasize = sizeof( CvCARTNode ) * (count + count);
    intnode = (CvCARTNode*) cvAlloc( datasize );
    memset( intnode, 0, datasize );
    list = (CvCARTNode*) (intnode + count);

    //节点分裂函数指针,一般为icvSplitIndicesCallback函数 
    splitIdxCallback = ((CvCARTTrainParams*) trainParams)->splitIdx;
    userdata = ((CvCARTTrainParams*) trainParams)->userdata;
    if( splitIdxCallback == NULL )//如果没有用默认的节点分裂函数
    {
        splitIdxCallback = ( CV_IS_ROW_SAMPLE( flags ) )
            ? icvDefaultSplitIdx_R : icvDefaultSplitIdx_C;//R代表样本按行排列,C代表样本按列排列 
        userdata = trainData;
    }

    /* create root of the tree */
    //创建CART弱分类器的根节点,如果该弱分类器只有一个特征,那这里就创建了弱分类器,不用后面作结点分裂 
    //stumpConstructor是一个函数指针,他指向cvCreateMTStumpClassifier函数,所以这里调用的是这个函数
    intnode[0].sampleIdx = sampleIdx;
    intnode[0].stump = (CvStumpClassifier*)
        ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
            trainClasses, typeMask, missedMeasurementsMask, compIdx, sampleIdx, weights,
            ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
    cart->left[0] = cart->right[0] = 0;

    /* build tree */
    //创建树状弱分类器,lerror或者rerror不为0代表着当前节点为非叶子节点 
    listcount = 0;
    for( i = 1; i < count; i++ )/*当弱分类器只有一个特征也就是只一个非叶子结点时,不会走入这个分支*/
    {
        /* split last added node */
        /*这个函数的作用就是:::基于当前结点的阈值将样本分类,
           分类为负样本的样本存储在lidx中,分类为正样本的样本存储在ridx,
           后续从当前结点左分支分裂时,用lidx样本来训练一个结点,
           从当前结点右分支分裂时,用ridx样本来训练一个结点*/
        splitIdxCallback( intnode[i-1].stump->compidx, intnode[i-1].stump->threshold,
            intnode[i-1].sampleIdx, &lidx, &ridx, userdata );

        //为分裂之后的非叶子节点计算最优特征
        if( intnode[i-1].stump->lerror != 0.0F )
        {
            //小于阈值的样本集合,就是当前结点的左分支结点的训练  
            list[listcount].sampleIdx = lidx;

            //基于新样本集合寻找最优特征,重复调用训练桩的函数来训练
            list[listcount].stump = (CvStumpClassifier*)
                ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
                    trainClasses, typeMask, missedMeasurementsMask, compIdx,
                    list[listcount].sampleIdx,
                    weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );

            //计算信息增益(这里是error的下降程度)
            list[listcount].errdrop = intnode[i-1].stump->lerror
                - (list[listcount].stump->lerror + list[listcount].stump->rerror);
            list[listcount].leftflag = 1;
            list[listcount].parent = i-1;
            listcount++;
        }
        else
        {
            cvReleaseMat( &lidx );
        }

        //同上,左分支换成右分支,偏向于右分支 
        if( intnode[i-1].stump->rerror != 0.0F )
        {
            list[listcount].sampleIdx = ridx;
            list[listcount].stump = (CvStumpClassifier*)
                ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
                    trainClasses, typeMask, missedMeasurementsMask, compIdx,
                    list[listcount].sampleIdx,
                    weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
            list[listcount].errdrop = intnode[i-1].stump->rerror
                - (list[listcount].stump->lerror + list[listcount].stump->rerror);
            list[listcount].leftflag = 0;//标识训练出来的节点是当前结点左分支结点还是右还是右分支结点 
            list[listcount].parent = i-1;
            listcount++;
        }
        else
        {
            cvReleaseMat( &ridx );
        }

        if( listcount == 0 ) break;

        /*find the best node to be added to the tree*/
        /*找到已经分裂得到的所有结点中,使分类误差下降最快的那个结点,
                            把它加入到CART树中去,构成弱分类器的一部分*/
        idx = 0;
        maxerrdrop = list[idx].errdrop;
        for( j = 1; j < listcount; j++ )
        {
            if( list[j].errdrop > maxerrdrop )
            {
                idx = j;
                maxerrdrop = list[j].errdrop;
            }
        }

        //确定误差下降最快的结点应该加入到CART树中的位置
        intnode[i] = list[idx];
        if( list[idx].leftflag )
        {
            cart->left[list[idx].parent] = i;
        }
        else
        {
            cart->right[list[idx].parent] = i;
        }
        //将被选中放入CART树的结点删除 
        if( idx != (listcount - 1) )
        {
            list[idx] = list[listcount - 1];
        }
        listcount--;
    }

    /* fill <cart> fields */
    // 这段代码用于确定树中节点最优特征序号、阈值与叶子节点序号和输出置信度  
    // left与right大于等于0,为0代表叶子节点  
    // 就算CART中只有一个节点,仍旧需要设置叶子节点 
    j = 0;
    cart->count = 0;
    for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
    {
        cart->count++;
        cart->compidx[i] = intnode[i].stump->compidx;
        cart->threshold[i] = intnode[i].stump->threshold;

        /* leaves */
        if( cart->left[i] <= 0 )//确定叶子序号与叶子的输出置信度
        {
            cart->left[i] = -j;
            cart->val[j] = intnode[i].stump->left;//这个left是float值,不是CVMat*  
            j++;
        }
        if( cart->right[i] <= 0 )
        {
            cart->right[i] = -j;
            cart->val[j] = intnode[i].stump->right;
            j++;
        }
    }

    /* CLEAN UP *//*一些临时用的内存释放*/
    for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
    {
        intnode[i].stump->release( (CvClassifier**) &(intnode[i].stump) );
        if( i != 0 )
        {
            cvReleaseMat( &(intnode[i].sampleIdx) );
        }
    }
    for( i = 0; i < listcount; i++ )
    {
        list[i].stump->release( (CvClassifier**) &(list[i].stump) );
        cvReleaseMat( &(list[i].sampleIdx) );
    }

    cvFree( &intnode );

    return (CvClassifier*) cart;   /*返回创建的弱分类器*/
}

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值