Using Decision Trees in OpenCV 3.0

The theory of decision trees is not covered here.
Just be aware that OpenCV 3.0 does not implement ID3 or C4.5;
it builds CART-style trees and measures node impurity with the Gini index.

Note that the API differs somewhat from earlier versions:
the old style used CvDTree* — see samples\c\mushroom.cpp in the OpenCV tree;
the new style uses Ptr<DTrees> — see samples\cpp\tree_engine.cpp.


Class name:

cv::ml::DTrees

Header file:

#include "opencv2/ml/ml.hpp"

Creation

Ptr<DTrees> dtree = DTrees::create();

Training

bool ok = model->train(data);

/*
Here model is a Ptr<StatModel>.
The inheritance chain is Algorithm <- StatModel <- DTrees <- RTrees,
so other algorithms such as SVM, Boost, and RTrees (the random trees classifier)
can all be trained through the same call.
*/

/** @brief Trains the statistical model

    @param trainData training data that can be loaded from file using TrainData::loadFromCSV or
        created with TrainData::create.
    @param flags optional flags, depending on the model. Some of the models can be updated with the
        new training samples, not completely overwritten (such as NormalBayesClassifier or ANN_MLP).
     */
    CV_WRAP virtual bool train( const Ptr<TrainData>& trainData, int flags=0 );

/*
The training data can be loaded from a CSV file, the plain comma-separated
format that Excel can export; you can inspect it with any text editor.
Because I had no prior experience and found no documented example of the
expected data layout, a demo and matching test data are attached at the end.
*/

Saving the trained model

dtree->save("trained_dtree.xml");

Loading the trained model

string dtreeFileName("trained_dtree.xml");
Ptr<ml::DTrees> dtree = Algorithm::load<DTrees>(dtreeFileName);
/*
Note that loading is not symmetric with dtree->save():
loading through an existing dtree instance fails with an
"empty root" error. You must obtain the pointer from one of
    Algorithm::load<DTrees>(dtreeFileName);
    StatModel::load<DTrees>(dtreeFileName);
    DTrees::load<DTrees>(dtreeFileName);
*/

Classifying with the decision tree

vector<float> testVec;
testVec.push_back(1);// push one value per feature of the sample
float resultKind = dtree->predict(testVec);


/** @brief Predicts response(s) for the provided sample(s)

    @param samples The input samples, floating-point matrix
    @param results The optional output matrix of results.
    @param flags The optional flags, model-dependent. See cv::ml::StatModel::Flags.
     */
    CV_WRAP virtual float predict( InputArray samples, OutputArray results=noArray(), int flags=0 ) const = 0;


Demo

This is the tree_engine.cpp sample with some modifications; it runs in an OpenCV 3.0 environment.
tree_engine.cpp

#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string>
#include <map>
#include <vector>
using namespace cv;
using namespace cv::ml;

static void help()
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t./tree_engine [-r <response_column>] [-ts type_spec] <csv filename>\n"
        "where -r <response_column> specifies the 0-based index of the response (0 by default)\n"
        "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n");
}

static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
{
    bool ok = model->train(data);
    if( !ok )
    {
        printf("Training failed\n");
    }
    else
    {
        printf( "train error: %f\n", model->calcError(data, false, noArray()) );
        printf( "test error: %f\n\n", model->calcError(data, true, noArray()) );
    }
}

int main(int argc, char** argv)
{
    if(argc < 2)
    {
        help();
        return 0;
    }
    const char* filename = 0;
    int response_idx = 0;
    std::string typespec;

    for(int i = 1; i < argc; i++)
    {
        if(strcmp(argv[i], "-r") == 0)
            sscanf(argv[++i], "%d", &response_idx);
        else if(strcmp(argv[i], "-ts") == 0)
            typespec = argv[++i];
        else if(argv[i][0] != '-' )
            filename = argv[i];
        else
        {
            printf("Error. Invalid option %s\n", argv[i]);
            help();
            return -1;
        }
    }

    printf("\nReading in %s...\n\n",filename);
    const double train_test_split_ratio = 0.5;

    Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);

    if( data.empty() )
    {
        printf("ERROR: File %s can not be read\n", filename);
        return 0;
    }

    data->setTrainTestSplitRatio(train_test_split_ratio);

    printf("======DTREE=====\n");
    Ptr<DTrees> dtree = DTrees::create();
    dtree->setMaxDepth(10);
    dtree->setMinSampleCount(2);
    dtree->setRegressionAccuracy(0);
    dtree->setUseSurrogates(false);
    dtree->setMaxCategories(16);
    dtree->setCVFolds(0);
    dtree->setUse1SERule(false);
    dtree->setTruncatePrunedTree(false);
    dtree->setPriors(Mat());
    train_and_print_errs(dtree, data);
    dtree->save("dtree_result.xml");

    if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
    {
        printf("======BOOST=====\n");
        Ptr<Boost> boost = Boost::create();
        boost->setBoostType(Boost::GENTLE);
        boost->setWeakCount(100);
        boost->setWeightTrimRate(0.95);
        boost->setMaxDepth(2);
        boost->setUseSurrogates(false);
        boost->setPriors(Mat());
        train_and_print_errs(boost, data);
    }

    printf("======RTREES=====\n");
    Ptr<RTrees> rtrees = RTrees::create();
    rtrees->setMaxDepth(10);
    rtrees->setMinSampleCount(2);
    rtrees->setRegressionAccuracy(0);
    rtrees->setUseSurrogates(false);
    rtrees->setMaxCategories(16);
    rtrees->setPriors(Mat());
    rtrees->setCalculateVarImportance(false);
    rtrees->setActiveVarCount(0);
    rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
    train_and_print_errs(rtrees, data);

    std::cout << "======TEST====="<<std::endl;
    Ptr<DTrees> dtree2 = DTrees::load<DTrees>("dtree_result.xml");
    std::vector<float>testVec;
    testVec.push_back(1);
    testVec.push_back(6);
    float resultKind = dtree2->predict(testVec);
    std::cout << "1,6:"<<resultKind<<std::endl;
    return 0;
}

showData.csv
The first column holds the class label and the next two columns hold a
two-dimensional feature (x and y, say), so this data set contains two classes.

1,1,2
1,1,3
1,2,4
1,1,4
1,1,5
1,2,5
2,1,20
2,4,2
2,3,4
2,2,25
2,1,40
2,4,1
2,3,3
2,2,80
1,1,2
1,1,3
1,2,4
1,1,4
1,1,5
1,2,5
2,1,20
2,4,2
2,3,4
2,2,25
2,1,40
2,4,1
2,3,3
2,2,80

Usage:
In the Visual Studio menu choose Debug -> <project> Properties ->
Configuration Properties -> Debugging -> Command Arguments and set it to showData.csv.
Put the file in the project directory (where the .vcxproj lives) and run.

Running it produces a training-result file, dtree_result.xml:

<?xml version="1.0"?>
<opencv_storage>
<opencv_ml_dtree>
  <format>3</format>
  <is_classifier>1</is_classifier>
  <var_all>3</var_all>
  <var_count>2</var_count>
  <ord_var_count>2</ord_var_count>
  <cat_var_count>1</cat_var_count>
  <training_params>
    <use_surrogates>0</use_surrogates>
    <max_categories>15</max_categories>
    <regression_accuracy>0.</regression_accuracy>
    <max_depth>10</max_depth>
    <min_sample_count>2</min_sample_count>
    <cross_validation_folds>0</cross_validation_folds></training_params>
  <global_var_idx>1</global_var_idx>
  <var_idx>
    0 1</var_idx>
  <var_type>
    0 0 1</var_type>
  <cat_ofs>
    0 0 0 0</cat_ofs>
  <class_labels>
    1 2</class_labels>
  <missing_subst>
    0. 0. 0.</missing_subst>
  <nodes>
    <_>
      <depth>0</depth>
      <value>2.</value>
      <norm_class_idx>1</norm_class_idx>
      <splits>
        <_><var>0</var>
          <quality>9.1999998092651367e+000</quality>
          <le>2.5000000000000000e+000</le></_></splits></_>
    <_>
      <depth>1</depth>
      <value>1.</value>
      <norm_class_idx>0</norm_class_idx>
      <splits>
        <_><var>1</var>
          <quality>10.</quality>
          <le>1.2500000000000000e+001</le></_></splits></_>
    <_>
      <depth>2</depth>
      <value>1.</value>
      <norm_class_idx>0</norm_class_idx></_>
    <_>
      <depth>2</depth>
      <value>2.</value>
      <norm_class_idx>1</norm_class_idx></_>
    <_>
      <depth>1</depth>
      <value>2.</value>
      <norm_class_idx>1</norm_class_idx></_></nodes></opencv_ml_dtree>
</opencv_storage>

This is the full decision tree. Reading it shows that the nodes are stored
in preorder (root, then left subtree, then right subtree) and that the
maximum depth is 2.
<var>0</var> refers to the first feature, and
<le>1.2500000000000000e+001</le> is the split threshold; whether it means
"less than" or "less than or equal" I could not find documented,
so corrections from readers are welcome.
Following the same pattern you can draw the whole tree yourself.
