决策树-DecisionTree

 

Opencv:

setMaxCategories/getMaxCategories函数:设置/获取分类(categorical)特征的最大类别数;若某分类特征的取值个数超过该值,训练时会先把这些取值聚成不超过该值的若干簇,再寻找(次优)划分,默认值为10;

setMaxDepth/getMaxDepth函数:设置/获取树的最大深度,默认值为INT_MAX;

setMinSampleCount/getMinSampleCount函数:设置/获取最小训练样本数,默认值为10;

setCVFolds/getCVFolds函数:设置/获取CVFolds(the number of cross-validation folds,交叉验证折数)值,默认值为10;如果此值大于1,则用K折交叉验证对构建好的决策树进行剪枝;

setUseSurrogates/getUseSurrogates函数:设置/获取是否构建代理划分(surrogate splits),代理划分可用于处理缺失数据并更准确地计算特征重要性,默认值为false;

setUse1SERule/getUse1SERule函数:设置/获取是否使用1-SE规则,默认值为true;

setTruncatePrunedTree/getTruncatePrunedTree函数:设置/获取剪枝后是否把被剪掉的分支从树中真正移除,默认值为true;

setRegressionAccuracy/getRegressionAccuracy函数:设置/获取回归时用于终止的标准,默认值为0.01;

setPriors/getPriors函数:设置/获取先验概率数值,用于调整决策树的偏好,默认值为空的Mat;

getRoots函数:获取根节点索引;

getNodes函数:获取所有节点(其他接口返回的节点索引即为该数组中的下标);

getSplits函数:获取所有拆分(split)(节点中记录的拆分索引即为该数组中的下标);

getSubsets函数:获取分类特征拆分所用的所有位集(bitsets);

 

 

示例代码:

//    ./decision_tree ../air.csv
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <map>
#include <string>
#include <vector>
using namespace std;
using namespace cv;
using namespace cv::ml;

// Prints the command-line usage of this sample to stdout.
// The executable is built as "decision_tree" (see add_executable in the
// accompanying CMakeLists), so that is the name shown in the usage line.
static void help()
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t./decision_tree [-r <response_column>] [-ts type_spec] <csv filename>\n"
        "where -r <response_column> specifies the 0-based index of the response (0 by default)\n"
        "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n");
}

// Trains `model` on `data`, then reports its error on the training subset
// followed by its error on the held-out test subset of `data`.
// If train() returns false, only "Training failed" is printed.
// (Per the OpenCV docs, calcError yields the % of misclassified samples for
// classifiers and the MSE for regressors.)
static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
{
    const bool trained = model->train(data);
    if (!trained)
    {
        printf("Training failed\n");
        return;
    }

    // second argument: false = training subset, true = test subset
    const float train_err = model->calcError(data, false, noArray());
    printf( "train error: %f\n", train_err );

    const float test_err = model->calcError(data, true, noArray());
    printf( "test error: %f\n\n", test_err );
}

int main(int argc, char** argv)
{
    if(argc < 2)
    {
        help();
        return 0;
    }
    const char* filename = 0;
    int response_idx = 0;
    std::string typespec;

    for(int i = 1; i < argc; i++)
    {
        if(strcmp(argv[i], "-r") == 0)
            sscanf(argv[++i], "%d", &response_idx);
        else if(strcmp(argv[i], "-ts") == 0)
            typespec = argv[++i];
        else if(argv[i][0] != '-' )
            filename = argv[i];
        else
        {
            printf("Error. Invalid option %s\n", argv[i]);
            help();
            return -1;
        }
    }

    printf("\nReading in %s...\n\n",filename);
    const double train_test_split_ratio = 0.7;

    Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);

    if( data.empty() )
    {
        printf("ERROR: File %s can not be read\n", filename);
        return 0;
    }

    data->setTrainTestSplitRatio(train_test_split_ratio);

    printf("======DTREE=====\n");
	int depth =13;
	for (int i=0;i<depth;i++){
	    Ptr<DTrees> dtree = DTrees::create();
    	dtree->setMaxDepth(i);
    	dtree->setMinSampleCount(2);
    	dtree->setRegressionAccuracy(0);
    	dtree->setUseSurrogates(false);
    	dtree->setMaxCategories(3);
    	dtree->setCVFolds(1);
    	dtree->setUse1SERule(true);
    	dtree->setTruncatePrunedTree(true);
    	dtree->setPriors(Mat());

		cout<<"depth "<<i<<" "<<endl;
    	train_and_print_errs(dtree, data);
    	dtree->save("dtree_result.xml");
	}



    if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
    {
        printf("======BOOST=====\n");
        Ptr<Boost> boost = Boost::create();
        boost->setBoostType(Boost::GENTLE);
        boost->setWeakCount(100);
        boost->setWeightTrimRate(0.95);
        boost->setMaxDepth(2);
        boost->setUseSurrogates(false);
        boost->setPriors(Mat());
        train_and_print_errs(boost, data);
    }

    printf("======RTREES=====\n");
	for (int i=0;i<depth;i++){
    	Ptr<RTrees> rtrees = RTrees::create();
    	rtrees->setMaxDepth(10);
    	rtrees->setMinSampleCount(2);
    	rtrees->setRegressionAccuracy(0);
    	rtrees->setUseSurrogates(false);
    	rtrees->setMaxCategories(3);
    	rtrees->setPriors(Mat());
    	rtrees->setCalculateVarImportance(false);
    	rtrees->setActiveVarCount(0);
    	rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
		
		cout<<"depth "<<i<<" "<<endl;
    	train_and_print_errs(rtrees, data);
	}




    std::cout << "======TEST====="<<std::endl;
	cv::Ptr<cv::ml::DTrees> dtree2 =cv::ml::DTrees::load("dtree_result.xml");
    std::vector<float>testVec;
    testVec.push_back(0.10111413);
    testVec.push_back(0.147943);
    testVec.push_back(0.1385576);
    testVec.push_back(0.35972223);
    float resultKind = dtree2->predict(testVec);
    std::cout << "label 2,pred "<<resultKind<<std::endl;
    return 0;
}

CMakeLists:

cmake_minimum_required(VERSION 3.7)
project(decision_tree)

# Default to an optimized build, but let -DCMAKE_BUILD_TYPE=... override it.
if(NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE Release)
endif()

# Request C++11 through CMake rather than hand-written -std flags, and append
# (never overwrite) CMAKE_CXX_FLAGS so user-supplied flags survive.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated")

#option(NCNN_OPENMP "openmp support" OFF)
find_package(OpenMP)
if(OPENMP_FOUND)
    message("OPENMP FOUND")
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
    set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()

# To build against a specific OpenCV build (e.g. /data/opencv-3.3.1/build),
# configure with: cmake -DOpenCV_DIR=/data/opencv-3.3.1/build ..
# find_package then provides OpenCV_INCLUDE_DIRS and OpenCV_LIBS.
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})

add_executable(decision_tree decision_tree.cpp)
target_link_libraries(decision_tree dl ${OpenCV_LIBS})

air.csv数据格式(label,feature1,feature2,feature3,feature4):

1.0,0.064481236,0.13365328,0.17795815,0.38472223
2.0,0.07460991,0.13368088,0.12907954,0.3875
2.0,0.09277384,0.18969564,0.14788848,0.49861112
2.0,0.16025226,0.20588656,0.10510066,0.4722222
3.0,0.18401018,0.27592123,0.119767964,0.4236111
2.0,0.07291968,0.13746512,0.14516687,0.4486111
2.0,0.07614595,0.12724264,0.14924836,0.49861112
3.0,0.12498321,0.1986561,0.11265315,0.4097222

输出结果:

======DTREE=====
depth 0 
train error: 46.975090
test error: 39.166668

depth 1 
train error: 38.078293
test error: 37.500000

depth 2 
train error: 34.163700
test error: 44.166668

depth 3 
train error: 27.758007
test error: 43.333332

depth 4 
train error: 22.775801
test error: 46.666668

depth 5 
train error: 19.217081
test error: 43.333332

depth 6 
train error: 14.234876
test error: 47.500000

depth 7 
train error: 10.320285
test error: 45.000000

depth 8 
train error: 7.473310
test error: 50.833332

depth 9 
train error: 6.049822
test error: 49.166668

depth 10 
train error: 3.914591
test error: 49.166668

depth 11 
train error: 3.202847
test error: 50.000000

depth 12 
train error: 2.846975
test error: 50.000000

======RTREES=====
depth 0 
train error: 0.711744
test error: 36.666668

depth 1 
train error: 0.711744
test error: 36.666668

depth 2 
train error: 0.711744
test error: 36.666668

depth 3 
train error: 0.711744
test error: 36.666668

depth 4 
train error: 0.711744
test error: 36.666668

depth 5 
train error: 0.711744
test error: 36.666668

depth 6 
train error: 0.711744
test error: 36.666668

depth 7 
train error: 0.711744
test error: 36.666668

depth 8 
train error: 0.711744
test error: 36.666668

depth 9 
train error: 0.711744
test error: 36.666668

depth 10 
train error: 0.711744
test error: 36.666668

depth 11 
train error: 0.711744
test error: 36.666668

depth 12 
train error: 0.711744
test error: 36.666668

======TEST=====
label 2,pred 2

 

Sklearn:

示例代码:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Fit scikit-learn decision trees of max_depth 1..15 on air.csv and export the last one.

Each row of air.csv is ``label,feature1,feature2,feature3,feature4``.
The exported tree.dot can also be rendered manually with:
    dot -Tpdf tree.dot -o tree.pdf
"""
from sklearn.tree import export_graphviz
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import pydot
import numpy as np

#dot -Tpdf tree.dot -o tree.pdf

CLASS_NAMES = ['score1', 'score2', 'score3']
FEATURE_NAMES = ['wind', 'snow', 'rain', 'sunny']


def load_air_csv(path):
    """Read ``path`` and return ``(features, labels)`` as NumPy arrays.

    Features (every column after the first) are parsed as floats; labels (the
    first column) are kept as the raw strings, e.g. ``"1.0"``. Blank lines are
    skipped and surrounding whitespace (including ``\\r`` from CRLF files) is
    stripped from each line.
    """
    features = []
    labels = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fields = line.split(",")
            labels.append(fields[0])
            features.append([float(v) for v in fields[1:]])
    return np.asarray(features), np.asarray(labels)


data, label = load_air_csv("air.csv")

# random_state seeds the shuffle so the split is reproducible: any fixed
# integer (42 here) yields the same split on every run.
X_train, X_test, y_train, y_test = train_test_split(data, label, random_state=42, test_size=0.2)

for depth in range(1, 16):
    tree = DecisionTreeClassifier(max_depth=depth, random_state=0)
    tree.fit(X_train, y_train)
    print('tree depth:{}   Train score:{:.3f}    Test score:{:.3f}'.format(
        depth, tree.score(X_train, y_train), tree.score(X_test, y_test)))

# `tree` is the last model from the loop above (max_depth=15).
# classification_report takes (y_true, y_pred) in that order; swapping them
# exchanges the reported precision and recall.
print(classification_report(y_test, tree.predict(X_test), target_names=CLASS_NAMES))

# Write the Graphviz description of the tree ...
export_graphviz(tree, out_file="tree.dot", class_names=CLASS_NAMES, feature_names=FEATURE_NAMES,
                impurity=False, filled=True)
# ... and render it to PNG.
(graph,) = pydot.graph_from_dot_file('tree.dot')
graph.write_png('./tree.png')

输出结果:

tree depth:1   Train score:0.641    Test score:0.543
tree depth:2   Train score:0.675    Test score:0.543
tree depth:3   Train score:0.719    Test score:0.543
tree depth:4   Train score:0.775    Test score:0.543
tree depth:5   Train score:0.803    Test score:0.494
tree depth:6   Train score:0.847    Test score:0.605
tree depth:7   Train score:0.891    Test score:0.580
tree depth:8   Train score:0.925    Test score:0.506
tree depth:9   Train score:0.969    Test score:0.519
tree depth:10   Train score:0.988    Test score:0.494
tree depth:11   Train score:0.997    Test score:0.506
tree depth:12   Train score:1.000    Test score:0.506
tree depth:13   Train score:1.000    Test score:0.506
tree depth:14   Train score:1.000    Test score:0.506
tree depth:15   Train score:1.000    Test score:0.506
              precision    recall  f1-score   support

      score1       0.61      0.44      0.51        25
      score2       0.56      0.53      0.55        45
      score3       0.30      0.55      0.39        11

    accuracy                           0.51        81
   macro avg       0.49      0.51      0.48        81
weighted avg       0.54      0.51      0.51        81

生成的决策树:

#dot -Tpdf tree.dot -o tree.pdf

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值