先看OpenCV中与决策树有关的结构。
CvDTreeSplit 表示树节点的一个可能分割。
CvDTreeNode 表示决策树中的一个节点。
CvDTreeParams 包含了训练决策树的所有参数。
CvDTreeTrainData 决策树的训练数据,为树全体共享。
CvDTree 此类实现了决策树,包含了训练/预测等等操作。
就普通用户而言,使用的流程可以用准备数据、训练决策树、利用决策树预测三个环节表示。
下面这个小例子包括了数据的创建、训练、预测和保存几个基本步骤的简单操作。
#include "opencv2/core/core_c.h"
#include "opencv2/ml/ml.hpp"
#include <iostream>
int main()
{
//init data
float fdata[5][2] = {{1,1},{1,1},{1,0},{0,1},{0,1}};
cv::Mat data(5,2,CV_32F,fdata);
float fresponses[5] ={1,1,0,0,0};
cv::Mat responses(5,1,CV_32F,fresponses);
float priors[]={1,1};
CvDTree *tree;
CvDTreeParams params( 8, // max depth
1, // min sample count
0, // regression accuracy: N/A here
true, // compute surrogate split, as we have missing data
15, // max number of categories (use sub-optimal algorithm for larger numbers)
0, // the number of cross-validation folds
true, // use 1SE rule => smaller tree
true, // throw away the pruned tree branches
priors // the array of priors, the bigger p_weight, the more attention
// to the poisonous mushrooms
// (a mushroom will be judjed to be poisonous with bigger chance)
);
tree = new CvDTree;
tree->train (data,CV_ROW_SAMPLE,responses,cv::Mat(),
cv::Mat(),cv::Mat(),cv::Mat(),
params);
//try predict
cv::Mat sample(1,2,CV_32F,cv::Scalar::all (1));
double r = tree->predict (sample,cv::Mat())->value;
std::cout << "r: "<< r << std::endl;
//save tree in the xml file
tree->save ("tree.xml","test_tree");
return 0;
}