adaboost训练——弱分类器训练的opencv源码详解 (2)

转自:http://blog.csdn.net/lanxuecc/article/details/52689008

接着上一个博客http://blog.csdn.net/lanxuecc/article/details/52688605在弱分类器训练的主体函数cvCreateCARTClassifier中我们看到主要是调用cvCreateMTStumpClassifier函数来训练得到弱分类器的结点,下面注释下这个函数

/*
 * cvCreateMTStumpClassifier
 *
 * Multithreaded stump classifier constructor
 * Includes huge train data support through callback function
 */
CV_BOOST_IMPL
CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData,              //训练样本HAAR特征值矩阵
                                         int flags,                     // 1.按行排列,0.按列排列 
                                         CvMat* trainClasses,           // 样本类别{-1,1}
                                         CvMat* /*typeMask*/,           // 为了便于回调函数统一格式
                                         CvMat* missedMeasurementsMask, // 未知,很少用到
                                         CvMat* compIdx,                // 特征序列(必须为NULL)(行向量)
                                         CvMat* sampleIdx,              // 实际训练样本序列(行向量)
                                         CvMat* weights,                // 实际训练样本样本权重(行向量)
                                         CvClassifierTrainParams* trainParams ) //这个结构体中指明一些参数和数据,比如分类误差计算方法,特征总数以及多线程运行时每个线程处理的特征数
{
    CvStumpClassifier* stump = NULL;        // 弱分类器(桩)
    int m = 0;                              // 样本总数
    int n = 0;                              // 所有特征个数   
    uchar* data = NULL;                     // trainData数据指针
    size_t cstep   = 0;                     // trainData一行字节数
    size_t sstep   = 0;                     // trainData元素字节数
    int    datan   = 0;                     // 预计算特征个数
    uchar* ydata = NULL;                    // trainClasses数据指针
    size_t ystep = 0;                       // trainClasses元素字节数
    uchar* idxdata = NULL;                  // sampleIdx数据指针
    size_t idxstep = 0;                     // sampleIdx单个元素字节数
    int    l = 0;                           // 实际训练样本个数    
    uchar* wdata = NULL;                    // weights数据指针
    size_t wstep = 0;                       // weights元素字节数

    /*sortedIdx为事先计算好的特征值-样本矩阵,包含有预计算的所有HAAR特征对应于所有样本的特征值(按大小排列) */
    uchar* sorteddata = NULL;               // sortedIdx数据指针
    int    sortedtype    = 0;               // sortedIdx元素类型
    size_t sortedcstep   = 0;               // sortedIdx一行字节数
    size_t sortedsstep   = 0;               // sortedIdx元素字节数
    int    sortedn       = 0;               // sortedIdx行数(预计算特征个数)
    int    sortedm       = 0;               // sortedIdx列数(实际训练样本个数)

    char* filter = NULL;                    // 样本存在标示(行向量),如果样本存在则为1,否则为0
    int i = 0;

    int compidx = 0;                        // 每组特征的起始序号
    int stumperror;                         // 计算阈值方法:1.misclass 2.gini 3.entropy 4.least sum of squares
    int portion;                            // 每组特征个数,对所有特征n进行分组处理,每组portion个

    /* private variables */
    CvMat mat;                              // 补充特征-样本矩阵
    CvValArray va;
    float lerror;                           // 阈值左侧误差
    float rerror;                           // 阈值右侧误差
    float left;                             // 置信度(左分支)
    float right;                            // 置信度(右分支)
    float threshold;                        // 阈值
    int optcompidx;                         // 最优特征

    float sumw;                             
    float sumwy;
    float sumwyy;

    /*临时变量,循环用*/
    int t_compidx;
    int t_n;

    int ti;
    int tj;
    int tk;

    uchar* t_data;                          // 指向data
    size_t t_cstep;                         // cstep
    size_t t_sstep;                         // sstep

    size_t matcstep;                        // mat一行字节数
    size_t matsstep;                        // mat元素字节数

    int* t_idx;                             // 样本序列
    /* end private variables */

    CV_Assert( trainParams != NULL );
    CV_Assert( trainClasses != NULL );
    CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
    CV_Assert( missedMeasurementsMask == NULL );
    CV_Assert( compIdx == NULL );

    // 计算阈值方法:1.misclass 2.gini 3.entropy 4.least sum of squares
    stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;

    //样本类别
    ydata = trainClasses->data.ptr;
    if( trainClasses->rows == 1 )
    {
        m = trainClasses->cols;
        ystep = CV_ELEM_SIZE( trainClasses->type );
    }
    else
    {
        m = trainClasses->rows;
        ystep = trainClasses->step;
    }

    //样本权重
    wdata = weights->data.ptr;
    if( weights->rows == 1 )
    {
        CV_Assert( weights->cols == m );
        wstep = CV_ELEM_SIZE( weights->type );
    }
    else
    {
        CV_Assert( weights->rows == m );
        wstep = weights->step;
    }

    //事先计算好的排序好的所有样本的所有特征值排序好的序号
    //sortedIdx为空,trainData为行向量(1*m);sortedIdx不为空,trainData为矩阵(m*datan);
    if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
    {
        sortedtype =
            CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
        assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
                || sortedtype == CV_32FC1 );
        sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
        sortedsstep = CV_ELEM_SIZE( sortedtype );
        sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
        sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
        sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
    }

    //事先计算好的排序好的所有样本的所有特征值
    if( trainData == NULL )                         //为空的情况没有遇到
    {
        assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
        n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
        assert( n > 0 );
    }
    else
    {
        assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
        data = trainData->data.ptr;
        if( CV_IS_ROW_SAMPLE( flags ) )             //trainData为矩阵
        {
            cstep = CV_ELEM_SIZE( trainData->type );
            sstep = trainData->step;
            assert( m == trainData->rows );
            datan = n = trainData->cols;
        }
        else                                        //trainData为向量
        {
            sstep = CV_ELEM_SIZE( trainData->type );
            cstep = trainData->step;
            assert( m == trainData->cols );
            datan = n = trainData->rows;
        }

        // trainData为矩阵,当trainData为向量时,datan = n = 1
        if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
        {
            n = ((CvMTStumpTrainParams*) trainParams)->numcomp;     // 总特征个数  
        }        
    }

    //预计算特征个数一定要小于特征总数
    assert( datan <= n );

    if( sampleIdx != NULL )     // 已经剔除小权值样本
    {
        assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
        idxdata = sampleIdx->data.ptr;
        idxstep = ( sampleIdx->rows == 1 )
            ? CV_ELEM_SIZE( sampleIdx->type ) : sampleIdx->step;
        l = ( sampleIdx->rows == 1 ) ? sampleIdx->cols : sampleIdx->rows;

        // sorteddata中存放的是所有训练样本,需要筛选出实际训练样本
        if( sorteddata != NULL )
        {
            filter = (char*) cvAlloc( sizeof( char ) * m );
            memset( (void*) filter, 0, sizeof( char ) * m );
            for( i = 0; i < l; i++ )
            {
                filter[(int) *((float*) (idxdata + i * idxstep))] = (char) 1;   // 存在则为1,不存在则为0
            }
        }
    }
    else                        // 未剔除小权值样本
    {
        l = m;
    }

    //桩,分配一个结点的内存空间,用来存储
    stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
    memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );

    //每组特征个数,个从理解是为多线程计算,为提高性能将所有特征分成很多组
    portion = ((CvMTStumpTrainParams*)trainParams)->portion;

    if( portion < 1 )
    {
        /* auto portion */
        portion = n;
    #ifdef _OPENMP
        portion /= omp_get_max_threads();        
    #endif /* _OPENMP */        
    }

    stump->eval = cvEvalStumpClassifier;
    stump->tune = NULL;
    stump->save = NULL;
    stump->release = cvReleaseStumpClassifier;

    stump->lerror = FLT_MAX;
    stump->rerror = FLT_MAX;
    stump->left  = 0.0F;
    stump->right = 0.0F;

    compidx = 0;

    // 并行计算,默认为关闭的
#ifdef _OPENMP
#pragma omp parallel private(mat, va, lerror, rerror, left, right, threshold, \
                                 optcompidx, sumw, sumwy, sumwyy, t_compidx, t_n, \
                                 ti, tj, tk, t_data, t_cstep, t_sstep, matcstep,  \
                                 matsstep, t_idx)
#endif /* _OPENMP */
    {
        lerror = FLT_MAX;
        rerror = FLT_MAX;
        left  = 0.0F;
        right = 0.0F;
        threshold = 0.0F;
        optcompidx = 0;

        sumw   = FLT_MAX;
        sumwy  = FLT_MAX;
        sumwyy = FLT_MAX;

        t_compidx = 0;
        t_n = 0;

        ti = 0;
        tj = 0;
        tk = 0;

        t_data = NULL;
        t_cstep = 0;
        t_sstep = 0;

        matcstep = 0;
        matsstep = 0;

        t_idx = NULL;

        mat.data.ptr = NULL;

        // 预计算特征个数小于特征总数,则说明存在新特征,用于计算样本的新特征,存放在mat中
        if( datan < n )
        {
            if( CV_IS_ROW_SAMPLE( flags ) )
            {
                mat = cvMat( m, portion, CV_32FC1, 0 );
                matcstep = CV_ELEM_SIZE( mat.type );
                matsstep = mat.step;
            }
            else
            {
                mat = cvMat( portion, m, CV_32FC1, 0 );
                matcstep = mat.step;
                matsstep = CV_ELEM_SIZE( mat.type );
            }
            mat.data.ptr = (uchar*) cvAlloc( sizeof( float ) * mat.rows * mat.cols );
        }

        // 将实际训练样本序列存放进t_idx
        if( filter != NULL || sortedn < n )
        {
            t_idx = (int*) cvAlloc( sizeof( int ) * m );
            if( sortedn == 0 || filter == NULL )
            {
                if( idxdata != NULL )
                {
                    for( ti = 0; ti < l; ti++ )
                    {
                        t_idx[ti] = (int) *((float*) (idxdata + ti * idxstep));
                    }
                }
                else
                {
                    for( ti = 0; ti < l; ti++ )
                    {
                        t_idx[ti] = ti;
                    }
                }                
            }
        }

    #ifdef _OPENMP
    #pragma omp critical(c_compidx)
    #endif /* _OPENMP */

        // 初始化计算特征范围
        {
            t_compidx = compidx;
            compidx += portion;
        }

        // 寻找最优弱分类器
        while( t_compidx < n )
        {
            t_n = portion;                      // 每组特征个数
            if( t_compidx < datan )             // 已经计算过的特征
            {
                t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
                t_data = data;
                t_cstep = cstep;
                t_sstep = sstep;
            }
            else                                // 新特征
            {
                t_n = ( t_n < (n - t_compidx) ) ? t_n : (n - t_compidx);
                t_cstep = matcstep;
                t_sstep = matsstep;
                t_data = mat.data.ptr - t_compidx * ((size_t) t_cstep );

                // 计算每个新特征对应于每个训练样本的特征值
                ((CvMTStumpTrainParams*)trainParams)->getTrainData( &mat,
                        sampleIdx, compIdx, t_compidx, t_n,
                        ((CvMTStumpTrainParams*)trainParams)->userdata );
            }

            /* 预计算特征部分,直接寻找最优特征,也就是传说中的最优弱分类器 */
            if( sorteddata != NULL )
            {
                if( filter != NULL )    //需要提取实际训练样本
                {
                    switch( sortedtype )
                    {
                        case CV_16SC1:  // 这里重复度很高,只注释一个分支,剩下的都一个道理

                            // 从一组特征(datan个预计算特征)中寻找最优特征
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                tk = 0;

                                // 提取实际训练样本
                                for( tj = 0; tj < sortedm; tj++ )
                                {
                                    int curidx = (int) ( *((short*) (sorteddata
                                            + ti * sortedcstep + tj * sortedsstep)) );
                                    if( filter[curidx] != 0 )
                                    {
                                        t_idx[tk++] = curidx;
                                    }
                                }

                                // 如果findStumpThreshold_32s返回值为1, 则更新最优特征
                                if( findStumpThreshold_32s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        (uchar*) t_idx, sizeof( int ), tk,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        case CV_32SC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                tk = 0;
                                for( tj = 0; tj < sortedm; tj++ )
                                {
                                    int curidx = (int) ( *((int*) (sorteddata
                                            + ti * sortedcstep + tj * sortedsstep)) );
                                    if( filter[curidx] != 0 )
                                    {
                                        t_idx[tk++] = curidx;
                                    }
                                }
                                if( findStumpThreshold_32s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        (uchar*) t_idx, sizeof( int ), tk,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        case CV_32FC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                tk = 0;
                                for( tj = 0; tj < sortedm; tj++ )
                                {
                                    int curidx = (int) ( *((float*) (sorteddata
                                            + ti * sortedcstep + tj * sortedsstep)) );
                                    if( filter[curidx] != 0 )
                                    {
                                        t_idx[tk++] = curidx;
                                    }
                                }
                                if( findStumpThreshold_32s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        (uchar*) t_idx, sizeof( int ), tk,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        default:
                            assert( 0 );
                            break;
                    }
                }
                else            //所有训练样本均参与计算
                {
                    switch( sortedtype )
                    {
                        case CV_16SC1:/*遍历特征寻找使左右误差最小的特征*/
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                /*
                                    t_data + ti * t_cstep:第ti个特征模版
                                    t_sstep:特征模版存储的跨距
                                    wdata:样本的权重 
                                    wstep:样本权重数组的跨距
                                    ydata:样本的类别标签
                                    ystep:样本的类别标签数组的跨距
                                    sorteddata + ti * sortedcstep:第ti个样本排序好的特征值的序号
                                    sortedsstep:跨距
                                    sortedm:序号的列数也就是实际样本列数
                                    lerror:阈值左侧误差
                                    rerror:阈值右侧误差
                                    threshold:阈值
                                    left:左分支置信度
                                    right:右分支置信度
                                    optcompidx:最优特征
                                */
                                if( findStumpThreshold_16s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        sorteddata + ti * sortedcstep, sortedsstep, sortedm,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        case CV_32SC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                if( findStumpThreshold_32s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        sorteddata + ti * sortedcstep, sortedsstep, sortedm,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        case CV_32FC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                if( findStumpThreshold_32f[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        sorteddata + ti * sortedcstep, sortedsstep, sortedm,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        default:
                            assert( 0 );
                            break;
                    }
                }
            }

            /* 新特征部分,要对样本特征值进行排序,然后再寻找最优特征 */
            ti = MAX( t_compidx, MIN( sortedn, t_compidx + t_n ) );
            for( ; ti < t_compidx + t_n; ti++ )
            {
                va.data = t_data + ti * t_cstep;
                va.step = t_sstep;

                // 对样本特征值进行排序
                icvSortIndexedValArray_32s( t_idx, l, &va );

                // 继续寻找最优特征
                if( findStumpThreshold_32s[stumperror]( 
                        t_data + ti * t_cstep, t_sstep,
                        wdata, wstep, ydata, ystep,
                        (uchar*)t_idx, sizeof( int ), l,
                        &lerror, &rerror,
                        &threshold, &left, &right, 
                        &sumw, &sumwy, &sumwyy ) )
                {
                    optcompidx = ti;
                }
            }
        #ifdef _OPENMP
        #pragma omp critical(c_compidx)
        #endif /* _OPENMP */

            // 更新特征计算范围
            {
                t_compidx = compidx;
                compidx += portion;
            }
        }

    #ifdef _OPENMP
    #pragma omp critical(c_beststump)
    #endif /* _OPENMP */

        // 设置最优弱分类器
        {
            if( lerror + rerror < stump->lerror + stump->rerror )
            {
                stump->lerror    = lerror;       
                stump->rerror    = rerror;       
                stump->compidx   = optcompidx;   
                stump->threshold = threshold;    
                stump->left      = left;         
                stump->right     = right;        
            }
        }

        /* free allocated memory */
        if( mat.data.ptr != NULL )
        {
            cvFree( &(mat.data.ptr) );
        }
        if( t_idx != NULL )
        {
            cvFree( &t_idx );
        }
    } /* end of parallel region */

    /* END */

    /* free allocated memory */
    if( filter != NULL )
    {
        cvFree( &filter );
    }
    // 如果设置为离散型,置信度应为1或者-1
    if( ((CvMTStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )  /*要满足这个条件才转成离散*/
    {
        stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;   /*在这里将计算出来的左右置信度浮点数转成1或-1*/
        stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
    }

    return (CvClassifier*) stump;
}
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431
  • 432
  • 433
  • 434
  • 435
  • 436
  • 437
  • 438
  • 439
  • 440
  • 441
  • 442
  • 443
  • 444
  • 445
  • 446
  • 447
  • 448
  • 449
  • 450
  • 451
  • 452
  • 453
  • 454
  • 455
  • 456
  • 457
  • 458
  • 459
  • 460
  • 461
  • 462
  • 463
  • 464
  • 465
  • 466
  • 467
  • 468
  • 469
  • 470
  • 471
  • 472
  • 473
  • 474
  • 475
  • 476
  • 477
  • 478
  • 479
  • 480
  • 481
  • 482
  • 483
  • 484
  • 485
  • 486
  • 487
  • 488
  • 489
  • 490
  • 491
  • 492
  • 493
  • 494
  • 495
  • 496
  • 497
  • 498
  • 499
  • 500
  • 501
  • 502
  • 503
  • 504
  • 505
  • 506
  • 507
  • 508
  • 509
  • 510
  • 511
  • 512
  • 513
  • 514
  • 515
  • 516
  • 517
  • 518
  • 519
  • 520
  • 521
  • 522
  • 523
  • 524
  • 525
  • 526
  • 527
  • 528
  • 529
  • 530
  • 531
  • 532
  • 533
  • 534
  • 535
  • 536
  • 537
  • 538
  • 539
  • 540
  • 541
  • 542
  • 543
  • 544
  • 545
  • 546
  • 547
  • 548
  • 549
  • 550
  • 551
  • 552
  • 553
  • 554
  • 555
  • 556
  • 557
  • 558
  • 559
  • 560
  • 561
  • 562
  • 563
  • 564
  • 565
  • 566
  • 567
  • 568
  • 569
  • 570
  • 571
  • 572
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431
  • 432
  • 433
  • 434
  • 435
  • 436
  • 437
  • 438
  • 439
  • 440
  • 441
  • 442
  • 443
  • 444
  • 445
  • 446
  • 447
  • 448
  • 449
  • 450
  • 451
  • 452
  • 453
  • 454
  • 455
  • 456
  • 457
  • 458
  • 459
  • 460
  • 461
  • 462
  • 463
  • 464
  • 465
  • 466
  • 467
  • 468
  • 469
  • 470
  • 471
  • 472
  • 473
  • 474
  • 475
  • 476
  • 477
  • 478
  • 479
  • 480
  • 481
  • 482
  • 483
  • 484
  • 485
  • 486
  • 487
  • 488
  • 489
  • 490
  • 491
  • 492
  • 493
  • 494
  • 495
  • 496
  • 497
  • 498
  • 499
  • 500
  • 501
  • 502
  • 503
  • 504
  • 505
  • 506
  • 507
  • 508
  • 509
  • 510
  • 511
  • 512
  • 513
  • 514
  • 515
  • 516
  • 517
  • 518
  • 519
  • 520
  • 521
  • 522
  • 523
  • 524
  • 525
  • 526
  • 527
  • 528
  • 529
  • 530
  • 531
  • 532
  • 533
  • 534
  • 535
  • 536
  • 537
  • 538
  • 539
  • 540
  • 541
  • 542
  • 543
  • 544
  • 545
  • 546
  • 547
  • 548
  • 549
  • 550
  • 551
  • 552
  • 553
  • 554
  • 555
  • 556
  • 557
  • 558
  • 559
  • 560
  • 561
  • 562
  • 563
  • 564
  • 565
  • 566
  • 567
  • 568
  • 569
  • 570
  • 571
  • 572

从上面函数的代码中观察到在遍历特征时会调用findStumpThreshold_16s、findStumpThreshold_32s、findStumpThreshold_32f数组中定义了的总共12个函数指针,根据参数的不同调用不同的函数,例如findStumpThreshold_16s中的四个函数指针如下:

/*这个数组的类型是一个函数指针*/
CvFindThresholdFunc findStumpThreshold_16s[4] = {
        icvFindStumpThreshold_misc_16s,
        icvFindStumpThreshold_gini_16s,
        icvFindStumpThreshold_entropy_16s,
        icvFindStumpThreshold_sq_16s
    };
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

例如函数指针icvFindStumpThreshold_misc_16s我并未找到函数实现在哪,

在这些数组声明的上面声明的一些宏,其实实现这些函数的

举个例子 
宏1

/* misclassification error
 * err = MIN( wpos, wneg );
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_MISC( suffix, type )                                \
    ICV_DEF_FIND_STUMP_THRESHOLD( misc_##suffix, type,                                   \
        wposl = 0.5F * ( wl + wyl );                                                     \
        wposr = 0.5F * ( wr + wyr );                                                     \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curright = 0.5F * ( 1.0F + curright );                                           \
        curlerror = MIN( wposl, wl - wposl );                                            \
        currerror = MIN( wposr, wr - wposr );                                            \
    )
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

宏2

#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                              \
        uchar* data, size_t datastep,                                                    \
        uchar* wdata, size_t wstep,                                                      \
        uchar* ydata, size_t ystep,                                                      \
        uchar* idxdata, size_t idxstep, int num,                                         \
        float* lerror,                                                                   \
        float* rerror,                                                                   \
        float* threshold, float* left, float* right,                                     \
        float* sumw, float* sumwy, float* sumwyy )   
{

。。。。。。。。。。。。。。。。

}
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

根据宏在预编译阶段的解析原理,在预编译阶段声明: 
ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 16s, short )时会被上述第一个宏代替,变成

ICV_DEF_FIND_STUMP_THRESHOLD( misc_16s, short,                                   \
        wposl = 0.5F * ( wl + wyl );                                                     \
        wposr = 0.5F * ( wr + wyr );                                                     \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curright = 0.5F * ( 1.0F + curright );                                           \
        curlerror = 2.0F * wposl * ( 1.0F - curleft );                                   \
        currerror = 2.0F * wposr * ( 1.0F - curright );                                  \
    )
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

接着上述ICV_DEF_FIND_STUMP_THRESHOLD宏解析后被CV_BOOST_IMPL int icvFindStumpThreshold_##suffix替代变成:

CV_BOOST_IMPL int icvFindStumpThreshold_misc_16s(                                              \
        uchar* data, size_t datastep,                                                    \
        uchar* wdata, size_t wstep,                                                      \
        uchar* ydata, size_t ystep,                                                      \
        uchar* idxdata, size_t idxstep, int num,                                         \
        float* lerror,                                                                   \
        float* rerror,                                                                   \
        float* threshold, float* left, float* right,                                     \
        float* sumw, float* sumwy, float* sumwyy )  
{

。。。。。。。。。。

/*后面函数体中的所有"type"都被替换成"short"*/

/*函数体中的"error"被*/
        wposl = 0.5F * ( wl + wyl );                                                     \
        wposr = 0.5F * ( wr + wyr );                                                     \
        curleft = 0.5F * ( 1.0F + curleft );                                             \
        curright = 0.5F * ( 1.0F + curright );                                           \
        curlerror = 2.0F * wposl * ( 1.0F - curleft );                                   \
        currerror = 2.0F * wposr * ( 1.0F - curright );  
       /*替换*/

} 
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

这样就实现了icvFindStumpThreshold_misc_16s的函数定义,其他的声明也是同样的方法,这就很巧妙的将12个很相似的函数,用宏声明的方式给分别定义了,而不用重复写很多代码。

所以findStumpThreshold_16s、findStumpThreshold_32s、findStumpThreshold_32f数组中函数指针指定的函数都是由下述宏实现的,只是要改下参数和error的计算方法::::

#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                              \
        uchar* data, size_t datastep,                                                    \
        uchar* wdata, size_t wstep,                                                      \
        uchar* ydata, size_t ystep,                                                      \
        uchar* idxdata, size_t idxstep, int num,                                         \
        float* lerror,                                                                   \
        float* rerror,                                                                   \
        float* threshold, float* left, float* right,                                     \
        float* sumw, float* sumwy, float* sumwyy ) 
{                                                                                        \
    int found = 0;                                                                       \
    float wyl  = 0.0F;                                                                   \
    float wl   = 0.0F;                                                                   \
    float wyyl = 0.0F;                                                                   \
    float wyr  = 0.0F;                                                                   \
    float wr   = 0.0F;                                                                   \
                                                                                         \
    float curleft  = 0.0F;                                                               \
    float curright = 0.0F;                                                               \
    float* prevval = NULL;                                                               \
    float* curval  = NULL;                                                               \
    float curlerror = 0.0F;                                                              \
    float currerror = 0.0F;                                                              \
    float wposl;                                                                         \
    float wposr;                                                                         \
                                                                                         \
    int i = 0;                                                                           \
    int idx = 0;                                                                         \
                                                                                         \
    wposl = wposr = 0.0F;                                                                \
    if( *sumw == FLT_MAX )                                                               \
    {                                                                                    \
        /* calculate sums */                                                             \
        float *y = NULL;                                                                 \
        float *w = NULL;                                                                 \
        float wy = 0.0F;                                                                 \
                                                                                         \
        *sumw   = 0.0F;                                                                  \
        *sumwy  = 0.0F;                                                                  \
        *sumwyy = 0.0F;                                                                  \
        for( i = 0; i < num; i++ )                                                       \
        {                                                                                \
            idx = (int) ( *((type*) (idxdata + i*idxstep)) );                            \
            w = (float*) (wdata + idx * wstep);                                          \
            *sumw += *w;                                                                 \ 
            y = (float*) (ydata + idx * ystep);                                          \
            wy = (*w) * (*y);                                                            \
            *sumwy += wy;                                                                \
            *sumwyy += wy * (*y);                                                        \
        }                                                                                \
    } 

    /*num:实际样本个数,遍历样本找到使左右误差最小的阈值curval位置*/                                                                                     \
    for( i = 0; i < num; i++ )                                                           \
    {                                                                                    \
        idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                \
        curval = (float*) (data + idx * datastep);                                       \
         /* for debug purpose */                                                         \
        if( i > 0 ) assert( (*prevval) <= (*curval) );                                   \
                                                                                         \
        wyr  = *sumwy - wyl;                                                             \
        wr   = *sumw  - wl;                                                              \
                                                                                         \
        if( wl > 0.0 ) curleft = wyl / wl;                                               \
        else curleft = 0.0F;                                                             \
                                                                                         \
        if( wr > 0.0 ) curright = wyr / wr;                                              \
        else curright = 0.0F;                                                            \
                                                                                         \
        error                                                                            \
                                                                                         \
        if( curlerror + currerror < (*lerror) + (*rerror) )                              \
        {                                                                                \
            (*lerror) = curlerror;                                                       \
            (*rerror) = currerror;                                                       \
            *threshold = *curval;                                                        \
            if( i > 0 ) {                                                                \
                *threshold = 0.5F * (*threshold + *prevval);                             \
            }                                                                            \
            *left  = curleft;                                                            \
            *right = curright;                                                           \
            found = 1;                                                                   \
        }                                                                                \
                                                                                         \
        do                                                                               \
        {                                                                                \
            wl  += *((float*) (wdata + idx * wstep));                                    \
            wyl += (*((float*) (wdata + idx * wstep)))                                   \
                * (*((float*) (ydata + idx * ystep)));                                   \
            wyyl += *((float*) (wdata + idx * wstep))                                    \
                * (*((float*) (ydata + idx * ystep)))                                    \
                * (*((float*) (ydata + idx * ystep)));                                   \
        }                                                                                \
        while( (++i) < num &&                                                            \
            ( *((float*) (data + (idx =                                                  \
                (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                 \
                == *curval ) );                                                          \
        --i;                                                                             \
        prevval = curval;                                                                \
    } /* for each value */                                                               \
                                                                                         \
    return found;                                                                        \
}
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104

这段代码的逻辑一句话概括就是:::遍历某特征的所有样本找到使分类的左右误差最小的阈值。


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值