// RandomTree.h
#ifndef RANDOM_TREE_H
#define RANDOM_TREE_H
#include <opencv2/core/core.hpp>
#include <opencv2/ml/ml.hpp>
#include <string>
#include <vector>
#include <iostream>
using namespace cv;
using namespace std;
class RandomTree
{
public:
RandomTree(string name);
void set_train_data_(vector<vector<float>> &train_data, vector<int> &label);
void Train();
void Save();
float AccOnTrain();
int Predict(vector<float> &f);
private:
string name_;
Ptr<ml::TrainData> train_data_;
Ptr<ml::RTrees> forest_;
};
RandomTree::RandomTree(string name)
{
name_ = name;
forest_ = cv::ml::RTrees::create();
forest_->setMaxDepth(10); //树的最大深度
forest_->setPriors(cv::Mat());
forest_->setRegressionAccuracy(0.01); //设置回归精度
//终止标准
forest_->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER + cv::TermCriteria::EPS, 100, 0.01));
forest_->setMinSampleCount(10); //节点的最小样本数量
forest_->setUseSurrogates(false);
forest_->setMaxCategories(