下面的内容很长,倒杯水(有茶或者咖啡更好),带上耳机,准备就绪再往下看。下面我们来看强分类器是如何训练的,该过程在CvCascadeBoost::train函数中完成,代码如下:
bool CvCascadeBoost::train( const CvFeatureEvaluator* _featureEvaluator,
int _numSamples,
int _precalcValBufSize, int _precalcIdxBufSize,
const CvCascadeBoostParams& _params )
{
bool isTrained = false;
CV_Assert( !data );
clear();
// 样本的数据都存在 _featureEvaluator 里面,这里把训练相关的数据都
// 用CvCascadeBoostTrainData类封装,内部创建了运行时需要的一些内存
// 方便后面使用
data = new CvCascadeBoostTrainData( _featureEvaluator, _numSamples,
_precalcValBufSize, _precalcIdxBufSize, _params );
CvMemStorage *storage = cvCreateMemStorage();
// 创建一个 CvSeq 序列,存放一个强分类器的所有弱分类器
weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
storage = 0;
set_params( _params );
if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
{
// 从_featureEvaluator->cls 中拷贝样本的类别信息到 data->responses
// 因为这两种boost方法计算式把类别从0/1该为-1/+1使用
data->do_responses_copy();
}
// 设置所有样本初始权值为1/n
update_weights( 0 );
cout << "+----+---------+---------+" << endl;
cout << "| N | HR | FA |" << endl;
cout << "+----+---------+---------+" << endl;
do
{
// 训练一个弱分类器,弱分类器是棵CART树
CvCascadeBoostTree* tree = new CvCascadeBoostTree;
if( !tree->train( data, subsample_mask, this ) )
{
delete tree;
break;
}
// 得到弱分类器加入序列
cvSeqPush( weak, &tree );
// 根据boost公式更新样本数据的权值
update_weights( tree );
// 根据用户输入参数,把一定比例的(0.05)权值最小的样本去掉
trim_weights();
// subsample_mask 保存每个样本是否参数训练的标记(值为0/1)
// 没有可用样本了,退出训练
if( cvCountNonZero(subsample_mask) == 0 )
break;
} // 如果当前强分类器达到了设置的虚警率要求或弱分类数目达到上限停止
while( !isErrDesired() && (weak->total < params.weak_count) );
if(weak->total > 0)
{
data->is_classifier = true;
data->free_train_data();
isTrained = true;
}
else
clear();
return isTrained;
}
代码中首先把训练相关的数据用CvCascadeBoostTrainData封装,一遍后面传递给其它函数,将每个样本的权值设置为1/N,N为总样本数。此后便开始进入弱分类器训练循环。我们接着来看弱分类器的训练,代码位于CvCascadeBoostTree::train中。
bool
CvBoostTree::train( CvDTreeTrainData* _train_data,
const CvMat* _subsample_idx, CvBoost* _ensemble )
{
clear();
ensemble = _ensemble;
data = _train_data;
data->shared = true;
return do_train( _subsample_idx );
}
注意这里的参数_ensemble实际是CvCascadeBoost类型的指针,转入调用CvBoostTree::do_train函数,传入的参数为参与训练的样本的索引数组,具体代码如下:
bool CvDTree::do_train( const CvMat* _subsample_idx )
{
bool result = false;
CV_FUNCNAME( "CvDTree::do_train" );
__BEGIN__;
// 创建CART树根节点,设置根节点是数据为输入数据集
root = data->subsample_data( _subsample_idx );
// 开始分割节点,向树上增加子节点,构成CART树。如果设置弱分类器
CV_CALL( try_split_node(root));
if( root->split )
{
CV_Assert( root->left );
CV_Assert( root->right );
if( data->params.cv_folds > 0 )
CV_CALL( prune_cv() );
if( !data->shared )
data->free_train_data();
result = true;
}
__END__;
return result;
}
创建一个root节点后,对root节点进行分割,调用try_split_node函数实现,代码如下:
void CvDTree::try_split_node( CvDTreeNode* node )
{
CvDTreeSplit* best_split = 0;
int i, n = node->sample_count, vi;
bool can_split = true;
double quality_scale;
// 计算当前节点的 value,节点的风险 node_risk
calc_node_value( node );
// 节点样本数目过少样本数(默认为10) 或者树深度达到设置值(默认为1),也就是一个分割节点
if( node->sample_count <= data->params.min_sample_count ||
node->depth >= data->params.max_depth )
can_split = false;
// is_classifer:false
if( can_split && data->is_classifier )
{
// check if we have a "pure" node,
// we assume that cls_count is filled by calc_node_value()
int* cls_count = data->counts->data.i;
int nz = 0, m = data->get_num_classes();
for( i = 0; i < m; i++ )
nz += cls_count[i] != 0;
if( nz == 1 ) // there is only one class
can_split = false;
}
else if( can_split )
{
// 平均error值很小了,说明已经分得很好,没必要继续下去 regression_accuracy (0.01)
if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
can_split = false;
}
if( can_split )
{
// 调用函数找到最优分割,弱分类器训练的重头戏
best_split = find_best_split(node);
// TODO: check the split quality ...
node->split = best_split;
}
if( !can_split || !best_split )
{
data->free_node_data(node);
return;
}
// ignore this
quality_scale = calc_node_dir( node );
// 级联参数 use_surrogates = use_1se_rule = truncate_pruned_tree = false;
if( data->params.use_surrogates )
{
// find all the surrogate splits
// and sort them by their similarity to the primary one
for( vi = 0; vi < data->var_count; vi++ )
{
CvDTreeSplit* split;
int ci = data->get_var_type(vi);
if( vi == best_split->var_idx )
continue;
if( ci >= 0 )
split = find_surrogate_split_cat( node, vi );
else
split = find_surrogate_split_ord( node, vi );
if( split )
{
// insert the split
CvDTreeSplit* prev_split = node->split;
split->quality