The theory behind decision trees will not be repeated here. Just note that OpenCV 3.0 does not use ID3 or C4.5; it uses the Gini index to measure node impurity.
Note that the API differs somewhat from earlier versions:
Old style: CvDTree* — see samples\c\mushroom.cpp in the OpenCV directory.
New style: Ptr<DTrees> — see samples\cpp\tree_engine.cpp.
Class name:
cv::ml::DTrees
Header:
#include "opencv2/ml/ml.hpp"
Creation
Ptr<DTrees> dtree = DTrees::create();
Training
bool ok = model->train(data);
/*
Here model is a Ptr<StatModel>.
The inheritance chain is Algorithm <-- StatModel <-- DTrees <-- RTrees,
so other algorithms such as SVM, Boost, and RTrees (Random Trees Classifier) can be trained the same way.
*/
/** @brief Trains the statistical model
@param trainData training data that can be loaded from file using TrainData::loadFromCSV or
created with TrainData::create.
@param flags optional flags, depending on the model. Some of the models can be updated with the
new training samples, not completely overwritten (such as NormalBayesClassifier or ANN_MLP).
*/
CV_WRAP virtual bool train( const Ptr<TrainData>& trainData, int flags=0 );
/*
The training data can be read from a CSV file (comma-separated values, the
format commonly opened in Excel; it is plain text and can be viewed in any
text editor). Since I had no prior experience with it and found no sample
data discussed elsewhere, a demo and the corresponding test data are
attached at the end.
*/
Saving the trained model
dtree->save("trained_dtree.xml");
Loading the trained model
string dtreeFileName("trained_dtree.xml");
Ptr<ml::DTrees> dtree = Algorithm::load<DTrees>(dtreeFileName);
/*
Note that loading is not symmetric with dtree->save(): loading through the
DTrees instance itself can leave you with an error where dtree's root is
empty. Be sure to obtain the pointer via
Algorithm::load<DTrees>(dtreeFileName),
or StatModel::load<DTrees>(dtreeFileName),
or DTrees::load<DTrees>(dtreeFileName).
*/
Classifying with the decision tree
vector<float> testVec;
testVec.push_back(1); // push the remaining feature values the same way
float resultKind = dtree->predict(testVec);
/** @brief Predicts response(s) for the provided sample(s)
@param samples The input samples, floating-point matrix
@param results The optional output matrix of results.
@param flags The optional flags, model-dependent. See cv::ml::StatModel::Flags.
*/
CV_WRAP virtual float predict( InputArray samples, OutputArray results=noArray(), int flags=0 ) const = 0;
Demo
This is tree_engine.cpp from the OpenCV samples with a few modifications; it runs in an OpenCV 3.0 environment.
tree_engine.cpp
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string>
#include <map>
#include <vector>
using namespace cv;
using namespace cv::ml;
static void help()
{
printf(
"\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
"Usage:\n\t./tree_engine [-r <response_column>] [-ts type_spec] <csv filename>\n"
"where -r <response_column> specifies the 0-based index of the response (0 by default)\n"
"-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
"<csv filename> is the name of training data file in comma-separated value format\n\n");
}
static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
{
bool ok = model->train(data);
if( !ok )
{
printf("Training failed\n");
}
else
{
printf( "train error: %f\n", model->calcError(data, false, noArray()) );
printf( "test error: %f\n\n", model->calcError(data, true, noArray()) );
}
}
int main(int argc, char** argv)
{
if(argc < 2)
{
help();
return 0;
}
const char* filename = 0;
int response_idx = 0;
std::string typespec;
for(int i = 1; i < argc; i++)
{
if(strcmp(argv[i], "-r") == 0)
sscanf(argv[++i], "%d", &response_idx);
else if(strcmp(argv[i], "-ts") == 0)
typespec = argv[++i];
else if(argv[i][0] != '-' )
filename = argv[i];
else
{
printf("Error. Invalid option %s\n", argv[i]);
help();
return -1;
}
}
printf("\nReading in %s...\n\n",filename);
const double train_test_split_ratio = 0.5;
Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
if( data.empty() )
{
printf("ERROR: File %s can not be read\n", filename);
return 0;
}
data->setTrainTestSplitRatio(train_test_split_ratio);
printf("======DTREE=====\n");
Ptr<DTrees> dtree = DTrees::create();
dtree->setMaxDepth(10);
dtree->setMinSampleCount(2);
dtree->setRegressionAccuracy(0);
dtree->setUseSurrogates(false);
dtree->setMaxCategories(16);
dtree->setCVFolds(0);
dtree->setUse1SERule(false);
dtree->setTruncatePrunedTree(false);
dtree->setPriors(Mat());
train_and_print_errs(dtree, data);
dtree->save("dtree_result.xml");
if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
{
printf("======BOOST=====\n");
Ptr<Boost> boost = Boost::create();
boost->setBoostType(Boost::GENTLE);
boost->setWeakCount(100);
boost->setWeightTrimRate(0.95);
boost->setMaxDepth(2);
boost->setUseSurrogates(false);
boost->setPriors(Mat());
train_and_print_errs(boost, data);
}
printf("======RTREES=====\n");
Ptr<RTrees> rtrees = RTrees::create();
rtrees->setMaxDepth(10);
rtrees->setMinSampleCount(2);
rtrees->setRegressionAccuracy(0);
rtrees->setUseSurrogates(false);
rtrees->setMaxCategories(16);
rtrees->setPriors(Mat());
rtrees->setCalculateVarImportance(false);
rtrees->setActiveVarCount(0);
rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
train_and_print_errs(rtrees, data);
std::cout << "======TEST====="<<std::endl;
Ptr<DTrees> dtree2 = DTrees::load<DTrees>("dtree_result.xml");
std::vector<float>testVec;
testVec.push_back(1);
testVec.push_back(6);
float resultKind = dtree2->predict(testVec);
std::cout << "1,6:"<<resultKind<<std::endl;
return 0;
}
showData.csv
The first column is the class label, and the next two columns are a two-dimensional feature (e.g. x and y), so this data set contains two classes in total.
1,1,2
1,1,3
1,2,4
1,1,4
1,1,5
1,2,5
2,1,20
2,4,2
2,3,4
2,2,25
2,1,40
2,4,1
2,3,3
2,2,80
1,1,2
1,1,3
1,2,4
1,1,4
1,1,5
1,2,5
2,1,20
2,4,2
2,3,4
2,2,25
2,1,40
2,4,1
2,3,3
2,2,80
Usage:
In the Visual Studio menu, choose Debug -> <project name> Properties -> Configuration Properties -> Debugging -> Command Arguments, and set the value to showData.csv.
Place showData.csv in the project directory (the folder containing the .vcxproj file), then run.
Running it produces a training-result file named dtree_result.xml:
<?xml version="1.0"?>
<opencv_storage>
<opencv_ml_dtree>
<format>3</format>
<is_classifier>1</is_classifier>
<var_all>3</var_all>
<var_count>2</var_count>
<ord_var_count>2</ord_var_count>
<cat_var_count>1</cat_var_count>
<training_params>
<use_surrogates>0</use_surrogates>
<max_categories>15</max_categories>
<regression_accuracy>0.</regression_accuracy>
<max_depth>10</max_depth>
<min_sample_count>2</min_sample_count>
<cross_validation_folds>0</cross_validation_folds></training_params>
<global_var_idx>1</global_var_idx>
<var_idx>
0 1</var_idx>
<var_type>
0 0 1</var_type>
<cat_ofs>
0 0 0 0</cat_ofs>
<class_labels>
1 2</class_labels>
<missing_subst>
0. 0. 0.</missing_subst>
<nodes>
<_>
<depth>0</depth>
<value>2.</value>
<norm_class_idx>1</norm_class_idx>
<splits>
<_><var>0</var>
<quality>9.1999998092651367e+000</quality>
<le>2.5000000000000000e+000</le></_></splits></_>
<_>
<depth>1</depth>
<value>1.</value>
<norm_class_idx>0</norm_class_idx>
<splits>
<_><var>1</var>
<quality>10.</quality>
<le>1.2500000000000000e+001</le></_></splits></_>
<_>
<depth>2</depth>
<value>1.</value>
<norm_class_idx>0</norm_class_idx></_>
<_>
<depth>2</depth>
<value>2.</value>
<norm_class_idx>1</norm_class_idx></_>
<_>
<depth>1</depth>
<value>2.</value>
<norm_class_idx>1</norm_class_idx></_></nodes></opencv_ml_dtree>
</opencv_storage>
This is the full decision tree. Reading through it, you can see that the nodes are stored in preorder (root, then left subtree, then right subtree), and the maximum depth here is 2.
<var>0</var>
refers to the first feature dimension, and
<le>1.2500000000000000e+001</le>
is the split threshold; I could not find documentation stating whether it means "less than" or "less than or equal to" — corrections from readers are welcome.
From this listing you can draw the tree yourself.