人脸检测源码解析——6、强分类器的诞生

本文深入解析OpenCV中的人脸检测算法,重点关注Adaboost如何结合Haar特征训练强分类器的过程。通过CvCascadeBoost::train函数的详细分析,揭示了目标检测技术的关键步骤。
摘要由CSDN通过智能技术生成

       下面的内容很长,倒杯水(有茶或者咖啡更好),带上耳机,准备就绪再往下看。下面我们来看强分类器是如何训练的,该过程在CvCascadeBoost::train函数中完成,代码如下:

bool CvCascadeBoost::train( const CvFeatureEvaluator* _featureEvaluator,
                           int _numSamples,
                           int _precalcValBufSize, int _precalcIdxBufSize,
                           const CvCascadeBoostParams& _params )
{
    bool isTrained = false;
    CV_Assert( !data );
    clear();
	// 样本的数据都存在 _featureEvaluator 里面,这里把训练相关的数据都
	// 用CvCascadeBoostTrainData类封装,内部创建了运行时需要的一些内存
	// 方便后面使用
    data = new CvCascadeBoostTrainData( _featureEvaluator, _numSamples,
                                        _precalcValBufSize, _precalcIdxBufSize, _params );
    CvMemStorage *storage = cvCreateMemStorage();
	// 创建一个 CvSeq 序列,存放一个强分类器的所有弱分类器
    weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
    storage = 0;


    set_params( _params );
    if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
	{
		// 从_featureEvaluator->cls 中拷贝样本的类别信息到 data->responses
		// 因为这两种boost方法计算式把类别从0/1该为-1/+1使用
		data->do_responses_copy();
	}
	// 设置所有样本初始权值为1/n
    update_weights( 0 );


    cout << "+----+---------+---------+" << endl;
    cout << "|  N |    HR   |    FA   |" << endl;
    cout << "+----+---------+---------+" << endl;


    do
    {
		// 训练一个弱分类器,弱分类器是棵CART树
        CvCascadeBoostTree* tree = new CvCascadeBoostTree;
        if( !tree->train( data, subsample_mask, this ) )
        {
            delete tree;
            break;
        }
		// 得到弱分类器加入序列
        cvSeqPush( weak, &tree );
		// 根据boost公式更新样本数据的权值
        update_weights( tree );
		// 根据用户输入参数,把一定比例的(0.05)权值最小的样本去掉
        trim_weights();
		// subsample_mask 保存每个样本是否参数训练的标记(值为0/1)
		// 没有可用样本了,退出训练
        if( cvCountNonZero(subsample_mask) == 0 )
            break;
    } // 如果当前强分类器达到了设置的虚警率要求或弱分类数目达到上限停止
    while( !isErrDesired() && (weak->total < params.weak_count) );


    if(weak->total > 0)
    {
        data->is_classifier = true;
        data->free_train_data();
        isTrained = true;
    }
    else
        clear();


    return isTrained;
}


        代码中首先把训练相关的数据用CvCascadeBoostTrainData封装,一遍后面传递给其它函数,将每个样本的权值设置为1/N,N为总样本数。此后便开始进入弱分类器训练循环。我们接着来看弱分类器的训练,代码位于CvCascadeBoostTree::train中。
bool
CvBoostTree::train( CvDTreeTrainData* _train_data,
                    const CvMat* _subsample_idx, CvBoost* _ensemble )
{
    clear();
    ensemble = _ensemble;
    data = _train_data;
    data->shared = true;
    return do_train( _subsample_idx );
}

        注意这里的参数_ensemble实际是CvCascadeBoost类型的指针,转入调用CvBoostTree::do_train函数,传入的参数为参与训练的样本的索引数组,具体代码如下:
bool CvDTree::do_train( const CvMat* _subsample_idx )
{
    bool result = false;


    CV_FUNCNAME( "CvDTree::do_train" );


    __BEGIN__;
	// 创建CART树根节点,设置根节点是数据为输入数据集
    root = data->subsample_data( _subsample_idx );
	// 开始分割节点,向树上增加子节点,构成CART树。如果设置弱分类器
    CV_CALL( try_split_node(root));


    if( root->split )
    {
        CV_Assert( root->left );
        CV_Assert( root->right );


        if( data->params.cv_folds > 0 )
            CV_CALL( prune_cv() );


        if( !data->shared )
            data->free_train_data();


        result = true;
    }


    __END__;


    return result;
}

        创建一个root节点后,对root节点进行分割,调用try_split_node函数实现,代码如下:
void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;
	// 计算当前节点的 value,节点的风险 node_risk
    calc_node_value( node );
	// 节点样本数目过少样本数(默认为10) 或者树深度达到设置值(默认为1),也就是一个分割节点
    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split = false;
	// is_classifer:false
    if( can_split && data->is_classifier )
    {
        // check if we have a "pure" node,
        // we assume that cls_count is filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    }
    else if( can_split )
    {
		// 平均error值很小了,说明已经分得很好,没必要继续下去 regression_accuracy (0.01)
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;
    }


    if( can_split )
    {
		// 调用函数找到最优分割,弱分类器训练的重头戏
        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;
    }
    if( !can_split || !best_split )
    {
        data->free_node_data(node);
        return;
    }
	// ignore this
    quality_scale = calc_node_dir( node );
	// 级联参数 use_surrogates = use_1se_rule = truncate_pruned_tree = false;
    if( data->params.use_surrogates )
    {
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
        {
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);


            if( vi == best_split->var_idx )
                continue;


            if( ci >= 0 )
                split = find_surrogate_split_cat( node, vi );
            else
                split = find_surrogate_split_ord( node, vi );


            if( split )
            {
                // insert the split
                CvDTreeSplit* prev_split = node->split;
                split->quality 
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值