Opencv:
setMaxCategories/getMaxCategories函数:设置/获取最大的类别数,默认值为10;
setMaxDepth/getMaxDepth函数:设置/获取树的最大深度,默认值为INT_MAX;
setMinSampleCount/getMinSampleCount函数:设置/获取最小训练样本数,默认值为10;
setCVFolds/getCVFolds函数:设置/获取CVFolds(the number of cross-validation folds)值,默认值为10,如果此值大于1,则使用K折交叉验证对构建好的决策树进行剪枝;
setUseSurrogates/getUseSurrogates函数:设置/获取是否使用surrogate splits(替代分裂)方法,默认值为false;
setUse1SERule/getUse1SERule函数:设置/获取是否使用1-SE规则,默认值为true;
setTruncatePrunedTree/getTruncatePrunedTree函数:设置/获取是否在剪枝后将被剪掉的分支从树中实际移除,默认值为true;
setRegressionAccuracy/getRegressionAccuracy函数:设置/获取回归时用于终止的标准,默认值为0.01;
setPriors/getPriors函数:设置/获取先验概率数值,用于调整决策树的偏好,默认值为空的Mat;
getRoots函数:获取根节点索引;
getNodes函数:获取所有节点索引;
getSplits函数:获取所有拆分索引;
getSubsets函数:获取分类拆分的所有bitsets
示例代码:
// ./decision_tree ../air.csv
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string.h>
#include <string>
#include <map>
#include <vector>
#include <iostream>
using namespace std;
using namespace cv;
using namespace cv::ml;
// Prints the command-line usage text to stdout.
// The program name in the usage line matches the executable built by the
// accompanying CMake project (decision_tree).
static void help()
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t./decision_tree [-r <response_column>] [-ts type_spec] <csv filename>\n"
        "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
        "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n");
}
// Trains `model` on `data` and reports the resulting errors on stdout.
// When training fails only "Training failed" is printed. Otherwise
// calcError() is evaluated twice: once with its second argument false
// (printed as the train error) and once with it true (printed as the
// test error).
static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
{
    if( !model->train(data) )
    {
        printf("Training failed\n");
        return;
    }

    const float train_err = model->calcError(data, false, noArray());
    printf( "train error: %f\n", train_err );

    const float test_err = model->calcError(data, true, noArray());
    printf( "test error: %f\n\n", test_err );
}
// Entry point of the decision-tree sample.
//
// Command line: [-r <response_column>] [-ts <type_spec>] <csv filename>
//
// Steps:
//   1. Parse the options and load the CSV file as training data.
//   2. Train a single decision tree for every max depth in [0, 13) and print
//      its train/test error; each trained tree is saved to dtree_result.xml,
//      so the file ends up holding the deepest tree.
//   3. Train a boosted classifier when there are at most two class labels.
//   4. Train a random forest for every max depth in [0, 13).
//   5. Reload dtree_result.xml and predict one hard-coded sample.
//
// Returns 0 on success (or when only the usage text is shown because no
// arguments were given) and -1 on any error.
int main(int argc, char** argv)
{
    if(argc < 2)
    {
        help();
        return 0;
    }

    const char* filename = nullptr;
    int response_idx = 0;
    std::string typespec;
    for(int i = 1; i < argc; i++)
    {
        if(strcmp(argv[i], "-r") == 0)
        {
            // -r must be followed by an integer; without the bounds check a
            // trailing "-r" would make sscanf read argv[argc] (a null pointer).
            if(i + 1 >= argc || sscanf(argv[++i], "%d", &response_idx) != 1)
            {
                printf("Error. Option -r requires an integer argument\n");
                help();
                return -1;
            }
        }
        else if(strcmp(argv[i], "-ts") == 0)
        {
            // Same guard for -ts: it needs a value after it.
            if(i + 1 >= argc)
            {
                printf("Error. Option -ts requires an argument\n");
                help();
                return -1;
            }
            typespec = argv[++i];
        }
        else if(argv[i][0] != '-')
        {
            filename = argv[i];
        }
        else
        {
            printf("Error. Invalid option %s\n", argv[i]);
            help();
            return -1;
        }
    }

    // Only options were given: there is nothing to load.
    if(filename == nullptr)
    {
        printf("Error. No csv filename given\n");
        help();
        return -1;
    }

    printf("\nReading in %s...\n\n", filename);
    const double train_test_split_ratio = 0.7;
    Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
    if( data.empty() )
    {
        printf("ERROR: File %s can not be read\n", filename);
        return -1;
    }
    data->setTrainTestSplitRatio(train_test_split_ratio);

    // Number of max-depth settings (0 .. depth-1) tried for each tree model.
    const int depth = 13;

    printf("======DTREE=====\n");
    for(int i = 0; i < depth; i++)
    {
        Ptr<DTrees> dtree = DTrees::create();
        dtree->setMaxDepth(i);
        dtree->setMinSampleCount(2);
        dtree->setRegressionAccuracy(0);
        dtree->setUseSurrogates(false);
        dtree->setMaxCategories(3);
        dtree->setCVFolds(1);
        dtree->setUse1SERule(true);
        dtree->setTruncatePrunedTree(true);
        dtree->setPriors(Mat());
        cout<<"depth "<<i<<" "<<endl;
        train_and_print_errs(dtree, data);
        // Overwritten on every iteration; the last (deepest) tree is the one
        // that the TEST section below reloads.
        dtree->save("dtree_result.xml");
    }

    if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
    {
        printf("======BOOST=====\n");
        Ptr<Boost> boost = Boost::create();
        boost->setBoostType(Boost::GENTLE);
        boost->setWeakCount(100);
        boost->setWeightTrimRate(0.95);
        boost->setMaxDepth(2);
        boost->setUseSurrogates(false);
        boost->setPriors(Mat());
        train_and_print_errs(boost, data);
    }

    printf("======RTREES=====\n");
    for(int i = 0; i < depth; i++)
    {
        Ptr<RTrees> rtrees = RTrees::create();
        // Use the loop variable as the max depth so that the printed
        // "depth i" label describes the forest that is actually trained
        // (previously every iteration trained with a fixed depth of 10).
        rtrees->setMaxDepth(i);
        rtrees->setMinSampleCount(2);
        rtrees->setRegressionAccuracy(0);
        rtrees->setUseSurrogates(false);
        rtrees->setMaxCategories(3);
        rtrees->setPriors(Mat());
        rtrees->setCalculateVarImportance(false);
        rtrees->setActiveVarCount(0);
        rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
        cout<<"depth "<<i<<" "<<endl;
        train_and_print_errs(rtrees, data);
    }

    std::cout << "======TEST====="<<std::endl;
    cv::Ptr<cv::ml::DTrees> dtree2 = cv::ml::DTrees::load("dtree_result.xml");
    if( dtree2.empty() )
    {
        printf("ERROR: File %s can not be read\n", "dtree_result.xml");
        return -1;
    }

    // One hand-picked feature vector (4 features, same order as the CSV
    // feature columns); the message below labels its expected class as 2.
    std::vector<float> testVec;
    testVec.push_back(0.10111413);
    testVec.push_back(0.147943);
    testVec.push_back(0.1385576);
    testVec.push_back(0.35972223);
    float resultKind = dtree2->predict(testVec);
    std::cout << "label 2,pred "<<resultKind<<std::endl;
    return 0;
}
CMakeLists:
cmake_minimum_required(VERSION 3.7)
project(decision_tree)

# Default to an optimized build, but let the user override it with
# -DCMAKE_BUILD_TYPE=... on the command line.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release)
endif()

# Request C++11 through CMake instead of hand-written -std= flags, and append
# to CMAKE_CXX_FLAGS rather than overwriting it so that flags supplied by the
# user or the toolchain are preserved.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated")

#option(NCNN_OPENMP "openmp support" OFF)
find_package(OpenMP)
if(OPENMP_FOUND)
  message("OPENMP FOUND")
  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
  set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()

# OpenCV is located through its CMake package config. For a non-standard
# install, pass its location at configure time, e.g.
#   cmake -DOpenCV_DIR=/data/opencv-3.3.1/build ..
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})

add_executable(decision_tree decision_tree.cpp)
# CMAKE_DL_LIBS expands to the platform's dynamic-loading library (dl on Linux).
target_link_libraries(decision_tree ${CMAKE_DL_LIBS} ${OpenCV_LIBS})
air.csv数据格式(label,feature1,feature2,feature3,feature4):
1.0,0.064481236,0.13365328,0.17795815,0.38472223
2.0,0.07460991,0.13368088,0.12907954,0.3875
2.0,0.09277384,0.18969564,0.14788848,0.49861112
2.0,0.16025226,0.20588656,0.10510066,0.4722222
3.0,0.18401018,0.27592123,0.119767964,0.4236111
2.0,0.07291968,0.13746512,0.14516687,0.4486111
2.0,0.07614595,0.12724264,0.14924836,0.49861112
3.0,0.12498321,0.1986561,0.11265315,0.4097222
输出结果:
======DTREE=====
depth 0
train error: 46.975090
test error: 39.166668
depth 1
train error: 38.078293
test error: 37.500000
depth 2
train error: 34.163700
test error: 44.166668
depth 3
train error: 27.758007
test error: 43.333332
depth 4
train error: 22.775801
test error: 46.666668
depth 5
train error: 19.217081
test error: 43.333332
depth 6
train error: 14.234876
test error: 47.500000
depth 7
train error: 10.320285
test error: 45.000000
depth 8
train error: 7.473310
test error: 50.833332
depth 9
train error: 6.049822
test error: 49.166668
depth 10
train error: 3.914591
test error: 49.166668
depth 11
train error: 3.202847
test error: 50.000000
depth 12
train error: 2.846975
test error: 50.000000
======RTREES=====
depth 0
train error: 0.711744
test error: 36.666668
depth 1
train error: 0.711744
test error: 36.666668
depth 2
train error: 0.711744
test error: 36.666668
depth 3
train error: 0.711744
test error: 36.666668
depth 4
train error: 0.711744
test error: 36.666668
depth 5
train error: 0.711744
test error: 36.666668
depth 6
train error: 0.711744
test error: 36.666668
depth 7
train error: 0.711744
test error: 36.666668
depth 8
train error: 0.711744
test error: 36.666668
depth 9
train error: 0.711744
test error: 36.666668
depth 10
train error: 0.711744
test error: 36.666668
depth 11
train error: 0.711744
test error: 36.666668
depth 12
train error: 0.711744
test error: 36.666668
======TEST=====
label 2,pred 2
Sklearn:
示例代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Trains decision trees of increasing depth on air.csv, prints the train/test
# accuracy for each depth, then reports per-class metrics and exports a
# picture of the last (deepest) tree.
# Each CSV row has the form: label,feature1,feature2,feature3,feature4
from sklearn.tree import export_graphviz
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import pydot
import numpy as np

# To render the exported tree as a PDF instead: dot -Tpdf tree.dot -o tree.pdf
features = []
labels = []
with open("air.csv", "r", encoding="utf-8") as f:
    for line in f:
        fields = line.strip().split(",")
        if len(fields) < 2:
            # Skip blank or malformed lines (e.g. a trailing newline at EOF).
            continue
        labels.append(fields[0])
        # Convert the features to numbers explicitly instead of relying on
        # scikit-learn to coerce an array of strings.
        features.append([float(value) for value in fields[1:]])
data = np.asarray(features, dtype=float)
label = np.asarray(labels)

# random_state seeds the random shuffling, so a fixed value (42 here) makes
# the train/test split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(data, label, random_state=42, test_size=0.2)

for depth in range(1, 16):
    tree = DecisionTreeClassifier(max_depth=depth, random_state=0)
    tree.fit(X_train, y_train)
    print('tree depth:{} Train score:{:.3f} Test score:{:.3f}'.format(depth, tree.score(X_train, y_train), tree.score(X_test, y_test)))

# classification_report expects (y_true, y_pred); passing them the other way
# round swaps the reported precision and recall.
print(classification_report(y_test, tree.predict(X_test), target_names=['score1', 'score2', 'score3']))

# Export the last trained tree in Graphviz dot format.
export_graphviz(tree, out_file="tree.dot", class_names=['score1', 'score2', 'score3'], feature_names=['wind', 'snow', 'rain', 'sunny'], impurity=False, filled=True)

# Render the dot file to a PNG image.
(graph,) = pydot.graph_from_dot_file('tree.dot')
graph.write_png('./tree.png')
输出结果:
tree depth:1 Train score:0.641 Test score:0.543
tree depth:2 Train score:0.675 Test score:0.543
tree depth:3 Train score:0.719 Test score:0.543
tree depth:4 Train score:0.775 Test score:0.543
tree depth:5 Train score:0.803 Test score:0.494
tree depth:6 Train score:0.847 Test score:0.605
tree depth:7 Train score:0.891 Test score:0.580
tree depth:8 Train score:0.925 Test score:0.506
tree depth:9 Train score:0.969 Test score:0.519
tree depth:10 Train score:0.988 Test score:0.494
tree depth:11 Train score:0.997 Test score:0.506
tree depth:12 Train score:1.000 Test score:0.506
tree depth:13 Train score:1.000 Test score:0.506
tree depth:14 Train score:1.000 Test score:0.506
tree depth:15 Train score:1.000 Test score:0.506
precision recall f1-score support
score1 0.61 0.44 0.51 25
score2 0.56 0.53 0.55 45
score3 0.30 0.55 0.39 11
accuracy 0.51 81
macro avg 0.49 0.51 0.48 81
weighted avg 0.54 0.51 0.51 81
生成的决策树:
#dot -Tpdf tree.dot -o tree.pdf