机器学习算法(一)-决策树代码(OpenCV3 )

OpenCV3中关于决策树的代码所在路径:samples\cpp\tree_engine.cpp
opencv官网相关函数解释:
http://docs.opencv.org/3.1.0/d8/d89/classcv_1_1ml_1_1DTrees.html

#include "opencv2/ml/ml.hpp"      //需要添加的头文件
#include "opencv2/core/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string>
#include <map>

using namespace cv;
using namespace cv::ml;

// Prints a short usage summary for this sample to stdout.
// Describes the supported command-line options:
//   -r=<response_column>  0-based index of the response column (default 0)
//   -ts=<type_spec>       variable type specification string
//   <csv filename>        training data in comma-separated-value format
static void help()
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t./tree_engine [-r=<response_column>] [-ts=type_spec] <csv filename>\n"
        "where -r=<response_column> specifies the 0-based index of the response (0 by default)\n"
        "-ts= specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n");
}

static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
{
    bool ok = model->train(data);
    if( !ok )
    {
        printf("Training failed\n");
    }
    else
    {
        printf( "train error: %f\n", model->calcError(data, false, noArray()) );
        printf( "test error: %f\n\n", model->calcError(data, true, noArray()) );
    }
}

// Demonstrates the OpenCV 3 ml decision-tree family on a CSV dataset:
// a single decision tree (DTrees), gentle AdaBoost (Boost, only for
// regression / 2-class problems), and random forests (RTrees).
//
// Returns 0 in all cases, including usage errors and unreadable input
// (original sample behavior, preserved for compatibility).
int main(int argc, char** argv)
{
    cv::CommandLineParser parser(argc, argv, "{ help h | | }{r | 0 | }{ts | | }{@input | | }");
    if (parser.has("help"))
    {
        help();
        return 0;
    }
    std::string filename = parser.get<std::string>("@input");
    int response_idx;
    std::string typespec;
    response_idx = parser.get<int>("r");     // 0-based column index of the response
    typespec = parser.get<std::string>("ts"); // variable type spec, e.g. ord[...]cat[...]
    if( filename.empty() || !parser.check() )
    {
        parser.printErrors();
        help();
        return 0;
    }
    printf("\nReading in %s...\n\n",filename.c_str());
    const double train_test_split_ratio = 0.5;

    // Response occupies a single column: [response_idx, response_idx+1).
    Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
    if( data.empty() )
    {
        printf("ERROR: File %s can not be read\n", filename.c_str());
        return 0;
    }

    data->setTrainTestSplitRatio(train_test_split_ratio);
    // Terminate the line: without '\n' the next banner would be glued
    // onto the same output line.
    std::cout << "Test/Train: " << data->getNTestSamples() << "/" << data->getNTrainSamples() << std::endl;

    printf("======DTREE=====\n");
    Ptr<DTrees> dtree = DTrees::create();
    // Maximum possible depth of the tree.
    dtree->setMaxDepth(10);
    // Minimum number of samples required at a leaf node for it to be split.
    dtree->setMinSampleCount(2);
    // Termination criterion for regression trees: N/A here.
    dtree->setRegressionAccuracy(0);
    dtree->setUseSurrogates(false);
    // Maximum number of categories (a sub-optimal algorithm is used for larger numbers).
    dtree->setMaxCategories(16);
    // If cvfolds > 1, the tree is pruned using K-fold cross-validation
    // with K = cvfolds; 0 disables cross-validation pruning.
    dtree->setCVFolds(0);
    // If true, prune more aggressively using the 1SE rule => smaller tree.
    dtree->setUse1SERule(false);
    // If true, pruned branches are physically removed from the tree.
    dtree->setTruncatePrunedTree(false);
    // Array of prior class probabilities, sorted by class label value;
    // an empty Mat means uniform priors.
    dtree->setPriors(Mat());
    train_and_print_errs(dtree, data);

    if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
    {
        printf("======BOOST=====\n");
        Ptr<Boost> boost = Boost::create();
        boost->setBoostType(Boost::GENTLE);
        boost->setWeakCount(100);
        boost->setWeightTrimRate(0.95);
        boost->setMaxDepth(2);
        boost->setUseSurrogates(false);
        boost->setPriors(Mat());
        train_and_print_errs(boost, data);
    }

    printf("======RTREES=====\n");
    Ptr<RTrees> rtrees = RTrees::create();
    rtrees->setMaxDepth(10);
    rtrees->setMinSampleCount(2);
    rtrees->setRegressionAccuracy(0);
    rtrees->setUseSurrogates(false);
    rtrees->setMaxCategories(16);
    rtrees->setPriors(Mat());
    // Enable computation of the variable importance array (read back below).
    rtrees->setCalculateVarImportance(true);
    // Size of the randomly selected feature subset tried at each node when
    // looking for the best split; 0 means sqrt(total feature count).
    rtrees->setActiveVarCount(0);
    rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
    train_and_print_errs(rtrees, data);

    // Demonstrate batch prediction on all samples (the predictions
    // themselves are not inspected further in this sample).
    cv::Mat predict_labels;
    rtrees->predict(data->getSamples(), predict_labels);

    cv::Mat variable_importance = rtrees->getVarImportance();
    std::cout << "Estimated variable importance" << std::endl;
    for (int i = 0; i < variable_importance.rows; i++) {
        std::cout << "Variable " << i << ": " << variable_importance.at<float>(i, 0) << std::endl;
    }
    return 0;
}
  1. 可执行代码参考博文:OpenCV3.0 决策树的使用
  2. 代码详解参考博文:OpenCv中基于决策树的分类任务代码解读
  3. 代码分析参考博文:Opencv2.4.9源码分析——Decision Trees
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值