// Boosted ("strong") classifier derived from CvBoost. It is trained from a
// CvFeatureEvaluator by CvCascadeBoost::train(), which adds CvCascadeBoostTree
// weak classifiers one at a time until isErrDesired() reports success or
// params.weak_count trees have been built.
class CvCascadeBoost : public CvBoost
{
public:
// Trains the classifier on _numSamples samples provided by _featureEvaluator.
// The two buffer-size arguments and _params are forwarded to the
// CvCascadeBoostTrainData constructor. Returns true when at least one weak
// classifier was trained, false otherwise.
virtual bool train( const CvFeatureEvaluator* _featureEvaluator,
int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
// NOTE(review): presumably evaluates the sample with index sampleIdx, with
// returnSum choosing between the raw weighted sum and the final decision --
// confirm against the definition, which is not visible here.
virtual float predict( int sampleIdx, bool returnSum = false ) const;
// Returns the current value of the `threshold` member.
float getThreshold() const { return threshold; }
// NOTE(review): presumably serializes the classifier to fs, using featureMap
// to translate feature indices -- confirm against the definition.
void write( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
// NOTE(review): presumably the inverse of write(); the meaning of the bool
// result is not visible here -- confirm against the definition.
bool read( const cv::FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
const CvCascadeBoostParams& _params );
// NOTE(review): presumably flags in featureMap the features referenced by the
// trained weak classifiers -- confirm against the definition.
void markUsedFeaturesInMap( cv::Mat& featureMap );
protected:
// Called by train() with the training parameters before any weak classifier
// is built. train() does not inspect the returned bool.
virtual bool set_params( const CvBoostParams& _params );
// Called by train() once with a null tree before the training loop, and then
// once after each newly trained weak classifier.
virtual void update_weights( CvBoostTree* tree );
// Stop test of the training loop: train() keeps adding weak classifiers while
// this returns false (and the weak-classifier limit has not been reached).
virtual bool isErrDesired();
// Value exposed through getThreshold(); where it is assigned is not shown in
// this declaration.
float threshold;
// NOTE(review): presumably the target minimum hit rate and maximum false-alarm
// rate checked by isErrDesired() -- confirm where they are set and used.
float minHitRate, maxFalseAlarm;
};
//开始训练强分类器。输入分别为:特征向量、正负样本总数、特征值所占内存量,索引所占内存量,参数
bool CvCascadeBoost::train( const CvFeatureEvaluator* _featureEvaluator,
int _numSamples,
int _precalcValBufSize, int _precalcIdxBufSize,
const CvCascadeBoostParams& _params )
{
bool isTrained = false;
CV_Assert( !data );
clear();//清理本对象中的矩阵所占内存等。
//读取数据
data = new CvCascadeBoostTrainData( _featureEvaluator, _numSamples,
_precalcValBufSize, _precalcIdxBufSize, _params );
//弱分类器所需内存
CvMemStorage *storage = cvCreateMemStorage();
weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
storage = 0;
//设置参数
set_params( _params );
if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
data->do_responses_copy();
//更新权重
update_weights( 0 );
cout << "+----+---------+---------+" << endl;
cout << "| N | HR | FA |" << endl;
cout << "+----+---------+---------+" << endl;
do
{
CvCascadeBoostTree* tree = new CvCascadeBoostTree;//弱分类器
if( !tree->train( data, subsample_mask, this ) )//训练
{
delete tree;
break;
}
cvSeqPush( weak, &tree );//保存弱分类器
update_weights( tree );//更新权重
trim_weights();
if( cvCountNonZero(subsample_mask) == 0 )
break;
}
while( !isErrDesired() && (weak->total < params.weak_count) );
if(weak->total > 0)
{
data->is_classifier = true;
data->free_train_data();
isTrained = true;
}
else
clear();
return isTrained;
}