关于决策树的介绍可以参考: https://blog.csdn.net/fengbingchun/article/details/78880934
CART算法的决策树的Python实现可以参考: https://blog.csdn.net/fengbingchun/article/details/78881143
这里参考 https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/ 这篇文章的原有Python实现,使用C++实现了决策树的CART算法,测试数据集是Banknote Dataset,关于Banknote Dataset的介绍可以参考: https://blog.csdn.net/fengbingchun/article/details/78624358 。
decision_tree.hpp文件内容如下:
#ifndef FBC_NN_DECISION_TREE_HPP_
#define FBC_NN_DECISION_TREE_HPP_
#include <vector>
#include <tuple>
#include <fstream>
namespace ANN {
// reference: https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
// CART decision tree for classification. T is the numeric type used for both
// feature values and class labels (explicitly instantiated for float/double).
template<typename T>
class DecisionTree { // CART(Classification and Regression Trees)
public:
DecisionTree() = default;
~DecisionTree() { delete_tree(); }
// Copies the training samples. Each row of `data` holds the feature values
// followed by the class label in its last column; `classes` lists the labels.
int init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes);
void set_max_depth(int max_depth) { this->max_depth = max_depth; }
int get_max_depth() const { return max_depth; }
void set_min_size(int min_size) { this->min_size = min_size; }
int get_min_size() const { return min_size; }
// Builds the tree from the data given to init() and prints training accuracy.
void train();
// Serialize/deserialize the model as a text file; both return 0 on success.
int save_model(const char* name) const;
int load_model(const char* name);
// Returns the predicted class label for a single feature row.
T predict(const std::vector<T>& data) const;
protected:
typedef std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> dictionary; // index of attribute, value of attribute, groups of data
typedef std::tuple<int, int, T, T, T> row_element; // flag, index, value, class_value_left, class_value_right
typedef struct binary_tree {
dictionary dict;
// Terminal labels returned by predict() when the corresponding child is
// absent; -1 marks "not a terminal on this side".
T class_value_left = (T)-1.f;
T class_value_right = (T)-1.f;
binary_tree* left = nullptr;
binary_tree* right = nullptr;
} binary_tree;
// Calculate the Gini index for a split dataset
T gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const;
// Select the best split point for a dataset
dictionary get_split(const std::vector<std::vector<T>>& dataset) const;
// Split a dataset based on an attribute and an attribute value
std::vector<std::vector<std::vector<T>>> test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const;
// Create a terminal node value
T to_terminal(const std::vector<std::vector<T>>& group) const;
// Create child splits for a node or make terminal
void split(binary_tree* node, int depth);
// Build a decision tree
void build_tree(const std::vector<std::vector<T>>& train);
// Print a decision tree
void print_tree(const binary_tree* node, int depth = 0) const;
// Make a prediction with a decision tree
T predict(binary_tree* node, const std::vector<T>& data) const;
// calculate accuracy percentage
double accuracy_metric() const;
void delete_tree();
void delete_node(binary_tree* node);
// Serialization helpers: the tree is flattened into a complete-binary-tree
// array (children of slot i at 2i+1 / 2i+2) of max_nodes row_elements.
void write_node(const binary_tree* node, std::ofstream& file) const;
void node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const;
int height_of_tree(const binary_tree* node) const;
void row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos);
private:
std::vector<std::vector<T>> src_data; // training rows: features + label column
binary_tree* tree = nullptr;
int samples_num = 0;
int feature_length = 0; // number of feature columns (row width - 1)
int classes_num = 0;
int max_depth = 10; // maximum tree depth
int min_size = 10; // minimum node records
int max_nodes = -1; // 2^max_depth - 1: slots in the serialized complete tree
};
} // namespace ANN
#endif // FBC_NN_DECISION_TREE_HPP_
decision_tree.cpp文件内容如下:
#include "decision_tree.hpp"
#include <algorithm>
#include <cstdio>
#include <iterator>
#include <limits>
#include <set>
#include <sstream>
#include <string>
#include <typeinfo>
#include "common.hpp"
namespace ANN {
template<typename T>
int DecisionTree<T>::init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes)
{
	// Record the dataset dimensions and keep a private copy of the samples.
	// Each row is laid out as [feature_0 .. feature_{n-1}, class_label], so
	// the feature count is one less than the row width.
	CHECK(data.size() != 0 && classes.size() != 0 && data[0].size() != 0);

	samples_num = static_cast<int>(data.size());
	classes_num = static_cast<int>(classes.size());
	feature_length = static_cast<int>(data[0].size()) - 1;

	src_data.insert(src_data.end(), data.cbegin(), data.cend());
	return 0;
}
template<typename T>
T DecisionTree<T>::gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const
{
	// Weighted Gini index of a candidate split:
	//   proportion = count(class_value) / group_size
	//   gini += (1 - sum(proportion^2)) * group_size / total_instances
	// Lower is better; 0.0 means every group is pure.
	// count all samples at the split point
	int instances = 0;
	for (const auto& group : groups)
		instances += static_cast<int>(group.size());

	T gini = (T)0.;
	for (const auto& group : groups) {
		const int size = static_cast<int>(group.size());
		if (size == 0) continue; // avoid divide by zero

		// score the group based on the proportion of each class
		T score = (T)0.;
		for (const auto& class_value : classes) {
			int count = 0;
			for (const auto& row : group) {
				if (row[this->feature_length] == class_value) ++count;
			}
			// fixed: the original declared an unused outer `T p` that was
			// shadowed by an inner `T p = (float)count / size`; the (float)
			// cast also lost precision when T == double.
			T p = (T)count / size;
			score += p * p;
		}
		// weight the group score by its relative size
		gini += ((T)1. - score) * size / instances;
	}
	return gini;
}
template<typename T>
std::vector<std::vector<std::vector<T>>> DecisionTree<T>::test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const
{
	// Partition `dataset` on attribute `index`: rows whose value at that
	// attribute is below `value` go to group 0 (left), all others to
	// group 1 (right).
	std::vector<std::vector<std::vector<T>>> groups(2);
	for (const auto& row : dataset) {
		const int side = (row[index] < value) ? 0 : 1;
		groups[side].emplace_back(row);
	}
	return groups;
}
template<typename T>
std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> DecisionTree<T>::get_split(const std::vector<std::vector<T>>& dataset) const
{
	// Exhaustively evaluate every (attribute, row-value) candidate split and
	// keep the one with the lowest Gini index.
	// Collect the distinct class labels present in this subset of rows.
	std::set<T> vals;
	for (const auto& row : dataset)
		vals.insert(row[this->feature_length]);
	std::vector<T> class_values(vals.cbegin(), vals.cend());

	// Best split found so far. fixed: the original used the magic number 999
	// as sentinel for index/value/score; use a real "worst possible" score so
	// the first candidate always wins, and -1 to mean "no split found".
	int b_index = -1;
	T b_value = (T)0.;
	T b_score = std::numeric_limits<T>::max();
	std::vector<std::vector<std::vector<T>>> b_groups(2);

	for (int index = 0; index < this->feature_length; ++index) {
		for (const auto& row : dataset) {
			auto groups = test_split(index, row[index], dataset);
			const T gini = gini_index(groups, class_values);
			if (gini < b_score) {
				b_index = index;
				b_value = row[index];
				b_score = gini;
				b_groups = std::move(groups); // avoid a deep copy of both groups
			}
		}
	}
	// the index of the chosen attribute, the value to split on, and the two
	// groups of data produced by that split
	return std::make_tuple(b_index, b_value, b_groups);
}
template<typename T>
T DecisionTree<T>::to_terminal(const std::vector<std::vector<T>>& group) const
{
	// Return the majority class label among the rows of `group`
	// (ties resolved in favor of the smaller label, set iteration order).
	std::vector<T> labels;
	labels.reserve(group.size());
	for (const auto& row : group)
		labels.emplace_back(row[this->feature_length]);

	const std::set<T> distinct(labels.cbegin(), labels.cend());
	T majority = (T)0.;
	int best_count = -1;
	for (const auto& label : distinct) {
		const int count = static_cast<int>(std::count(labels.cbegin(), labels.cend(), label));
		if (count > best_count) {
			best_count = count;
			majority = label;
		}
	}
	return majority;
}
template<typename T>
void DecisionTree<T>::split(binary_tree* node, int depth)
{
	// Recursively grow the tree below `node`. A child becomes a terminal
	// class value when the split is degenerate, the depth limit is reached,
	// or the child group is too small; otherwise a new node is created and
	// split again one level deeper.
	// fixed: move the groups out instead of deep-copying them (they are
	// cleared immediately afterwards anyway).
	std::vector<std::vector<T>> left = std::move(std::get<2>(node->dict)[0]);
	std::vector<std::vector<T>> right = std::move(std::get<2>(node->dict)[1]);
	std::get<2>(node->dict).clear(); // groups are no longer needed on the node

	// check for a no split: every row fell on one side, so both children
	// share the majority label of the combined rows
	if (left.empty() || right.empty()) {
		for (const auto& row : right)
			left.emplace_back(row);
		node->class_value_left = node->class_value_right = to_terminal(left);
		return;
	}
	// check for max depth
	if (depth >= max_depth) {
		node->class_value_left = to_terminal(left);
		node->class_value_right = to_terminal(right);
		return;
	}
	// process left child (cast avoids a signed/unsigned comparison)
	if (static_cast<int>(left.size()) <= min_size) {
		node->class_value_left = to_terminal(left);
	} else {
		node->left = new binary_tree;
		node->left->dict = get_split(left);
		split(node->left, depth + 1);
	}
	// process right child
	if (static_cast<int>(right.size()) <= min_size) {
		node->class_value_right = to_terminal(right);
	} else {
		node->right = new binary_tree;
		node->right->dict = get_split(right);
		split(node->right, depth + 1);
	}
}
template<typename T>
void DecisionTree<T>::build_tree(const std::vector<std::vector<T>>& train)
{
	// Build a fresh tree from `train`.
	// fixed: a previously built tree was leaked when train() was called more
	// than once; release it before installing the new root.
	if (tree != nullptr) {
		delete_node(tree);
		tree = nullptr;
	}
	// create the root node from the best split of the whole dataset
	binary_tree* root = new binary_tree;
	root->dict = get_split(train);
	tree = root;
	split(root, 1);
}
template<typename T>
void DecisionTree<T>::train()
{
	// A tree capped at max_depth levels has at most 2^max_depth - 1 nodes;
	// this bound sizes the flat array used by save_model().
	max_nodes = (1 << max_depth) - 1;
	build_tree(src_data);
	accuracy_metric(); // report accuracy on the training data
}
template<typename T>
T DecisionTree<T>::predict(const std::vector<T>& data) const
{
	// Guard against use before train()/load_model(); -1111 is the error marker.
	if (tree == nullptr) {
		fprintf(stderr, "Error, tree is null\n");
		return -1111.f;
	}
	return predict(tree, data);
}
template<typename T>
T DecisionTree<T>::predict(binary_tree* node, const std::vector<T>& data) const
{
	// Walk the tree: compare the sample's value at this node's attribute with
	// the node's split value, descending left when it is smaller. A missing
	// child means this side is a terminal and carries the class label.
	const bool go_left = data[std::get<0>(node->dict)] < std::get<1>(node->dict);
	if (go_left)
		return node->left ? predict(node->left, data) : node->class_value_left;
	return node->right ? predict(node->right, data) : node->class_value_right;
}
template<typename T>
int DecisionTree<T>::save_model(const char* name) const
{
	// Serialize the model as text: a "max_depth,min_size" header line, then
	// one line per slot of the complete-binary-tree array written by
	// write_node(). Returns 0 on success, -1 if the file cannot be opened.
	std::ofstream file(name, std::ios::out);
	if (!file.is_open()) {
		fprintf(stderr, "open file fail: %s\n", name);
		return -1;
	}
	file << max_depth << "," << min_size << std::endl;

	// fixed: the original asserted depth == max_depth, which aborts whenever
	// the tree legitimately stops growing early (small or pure nodes become
	// terminals); only the upper bound must hold for the array layout.
	const int depth = height_of_tree(tree);
	CHECK(depth <= max_depth);

	write_node(tree, file);
	file.close();
	return 0;
}
template<typename T>
void DecisionTree<T>::write_node(const binary_tree* node, std::ofstream& file) const
{
	// Flatten the tree into a fixed-size array in heap layout (children of
	// slot i at 2i+1 / 2i+2), then dump one CSV row per slot. Slots that were
	// never filled keep flag -1 meaning "no node".
	const row_element empty_slot = std::make_tuple(-1, -1, (T)-1.f, (T)-1.f, (T)-1.f);
	std::vector<row_element> rows(this->max_nodes, empty_slot);
	node_to_row_element(const_cast<binary_tree*>(node), rows, 0);

	for (const auto& row : rows) {
		file << std::get<0>(row) << "," << std::get<1>(row) << "," << std::get<2>(row)
			<< "," << std::get<3>(row) << "," << std::get<4>(row) << std::endl;
	}
}
template<typename T>
void DecisionTree<T>::node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const
{
	// Pre-order copy of the tree into heap-array layout; flag 0 marks a live
	// node, -1 (the slot default) an absent one.
	if (node == nullptr) return;
	rows[pos] = std::make_tuple(0, std::get<0>(node->dict), std::get<1>(node->dict),
		node->class_value_left, node->class_value_right);
	node_to_row_element(node->left, rows, 2 * pos + 1);
	node_to_row_element(node->right, rows, 2 * pos + 2);
}
template<typename T>
int DecisionTree<T>::height_of_tree(const binary_tree* node) const
{
	// Height counted in nodes: an empty tree is 0, a lone root is 1.
	if (node == nullptr) return 0;
	const int left_height = height_of_tree(node->left);
	const int right_height = height_of_tree(node->right);
	return 1 + std::max(left_height, right_height);
}
template<typename T>
int DecisionTree<T>::load_model(const char* name)
{
	// Parse the text model written by save_model(): a "max_depth,min_size"
	// header followed by exactly max_nodes rows of
	// "flag,index,value,class_value_left,class_value_right".
	// Returns 0 on success, -1 if the file cannot be opened.
	std::ifstream file(name, std::ios::in);
	if (!file.is_open()) {
		fprintf(stderr, "open file fail: %s\n", name);
		return -1;
	}

	std::string line, cell;
	std::getline(file, line);
	std::stringstream header_stream(line);
	std::vector<int> header;
	while (std::getline(header_stream, cell, ',')) {
		header.emplace_back(std::stoi(cell));
	}
	CHECK(header.size() == 2);
	max_depth = header[0];
	min_size = header[1];
	max_nodes = (1 << max_depth) - 1;

	std::vector<row_element> rows(max_nodes);
	int count = 0;
	while (std::getline(file, line)) {
		CHECK(count < max_nodes); // guard rows[count] against a corrupt file
		std::stringstream row_stream(line);
		std::vector<T> fields;
		while (std::getline(row_stream, cell, ',')) {
			// std::stod followed by a narrowing cast handles both the float
			// and the double instantiation, so no typeid dispatch is needed.
			fields.emplace_back(static_cast<T>(std::stod(cell)));
		}
		CHECK(fields.size() == 5);
		// fixed: the double branch of the original read vec[4] — the
		// 2-element header vector — instead of the parsed row (out-of-bounds
		// access / garbage class value).
		rows[count] = std::make_tuple((int)fields[0], (int)fields[1], fields[2], fields[3], fields[4]);
		++count;
	}
	CHECK(max_nodes == count);
	CHECK(std::get<0>(rows[0]) != -1); // the root slot must be present

	// fixed: release a previously trained/loaded tree before installing the
	// new root (the original leaked it).
	if (tree != nullptr) {
		delete_node(tree);
		tree = nullptr;
	}
	binary_tree* root = new binary_tree;
	std::vector<std::vector<std::vector<T>>> dump; // groups are not serialized
	root->dict = std::make_tuple(std::get<1>(rows[0]), std::get<2>(rows[0]), dump);
	root->class_value_left = std::get<3>(rows[0]);
	root->class_value_right = std::get<4>(rows[0]);
	tree = root;
	row_element_to_node(root, rows, max_nodes, 0);

	file.close();
	return 0;
}
template<typename T>
void DecisionTree<T>::row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos)
{
	// Rebuild the children of `node` from the heap-layout array: slot 2*pos+1
	// holds the left child, 2*pos+2 the right; flag -1 marks an empty slot.
	if (node == nullptr || n == 0) return;

	// Construct a child node from the row stored at slot `idx`.
	auto make_child = [&rows](int idx) -> binary_tree* {
		binary_tree* child = new binary_tree;
		std::vector<std::vector<std::vector<T>>> dump;
		child->dict = std::make_tuple(std::get<1>(rows[idx]), std::get<2>(rows[idx]), dump);
		child->class_value_left = std::get<3>(rows[idx]);
		child->class_value_right = std::get<4>(rows[idx]);
		return child;
	};

	const int left_pos = 2 * pos + 1;
	if (left_pos < n && std::get<0>(rows[left_pos]) != -1) {
		node->left = make_child(left_pos);
		row_element_to_node(node->left, rows, n, left_pos);
	}
	const int right_pos = 2 * pos + 2;
	if (right_pos < n && std::get<0>(rows[right_pos]) != -1) {
		node->right = make_child(right_pos);
		row_element_to_node(node->right, rows, n, right_pos);
	}
}
template<typename T>
void DecisionTree<T>::delete_tree()
{
	// Free all nodes and reset the root pointer.
	// fixed: the original left `tree` dangling after deletion and called
	// delete_node unconditionally, which crashed on a never-built tree.
	if (tree != nullptr) {
		delete_node(tree);
		tree = nullptr;
	}
}
template<typename T>
void DecisionTree<T>::delete_node(binary_tree* node)
{
	// Post-order deletion of the subtree rooted at `node`.
	// fixed: the original dereferenced node->left/right without checking
	// `node` itself, so delete_node(nullptr) — reachable via the destructor
	// of a never-trained tree — was a null-pointer dereference.
	if (node == nullptr) return;
	delete_node(node->left);
	delete_node(node->right);
	delete node;
}
template<typename T>
double DecisionTree<T>::accuracy_metric() const
{
	// Percentage of training samples the tree classifies correctly; the true
	// label is the last column of each stored row.
	int correct = 0;
	for (int i = 0; i < samples_num; ++i) {
		const T predicted = predict(tree, src_data[i]);
		if (predicted == src_data[i][this->feature_length]) ++correct;
	}
	const double accuracy = 100. * correct / samples_num;
	fprintf(stdout, "train accuracy: %f\n", accuracy);
	return accuracy;
}
template<typename T>
void DecisionTree<T>::print_tree(const binary_tree* node, int depth) const
{
	// Pretty-print the tree for debugging: one line per split ("[Xi < v]",
	// attribute indices 1-based) and one per terminal label.
	// fixed: the original did `blank += blank` in a loop, doubling the string
	// per level (2^depth spaces) — indentation is now linear in depth.
	if (node == nullptr) return;

	const std::string indent(depth + 1, ' ');
	fprintf(stdout, "%s[X%d < %.3f]\n", indent.c_str(), std::get<0>(node->dict) + 1, std::get<1>(node->dict));

	// terminals print one level deeper than their parent split
	const std::string leaf_indent(depth + 2, ' ');
	if (node->left)
		print_tree(node->left, depth + 1);
	else
		fprintf(stdout, "%s[%.1f]\n", leaf_indent.c_str(), node->class_value_left);
	if (node->right)
		print_tree(node->right, depth + 1);
	else
		fprintf(stdout, "%s[%.1f]\n", leaf_indent.c_str(), node->class_value_right);
}
// Explicit instantiations: the template member definitions live in this
// translation unit, so instantiate here for the supported element types.
template class DecisionTree<float>;
template class DecisionTree<double>;
} // namespace ANN
对外提供两个接口,一个是test_decision_tree_train用于训练,一个是test_decision_tree_predict用于测试,其code如下:
// =============================== decision tree ==============================
int test_decision_tree_train()
{
	// Train a CART decision tree on the banknote authentication dataset
	// (1372 rows, 4 features + 1 label column) and persist the model.
	// Returns 0 on success, -1 on failure.
	// A tiny two-feature smoke-test dataset is kept below, commented out.
	/*const std::vector<std::vector<float>> data{ { 2.771244718f, 1.784783929f, 0.f },
	{ 1.728571309f, 1.169761413f, 0.f },
	{ 3.678319846f, 2.81281357f, 0.f },
	{ 3.961043357f, 2.61995032f, 0.f },
	{ 2.999208922f, 2.209014212f, 0.f },
	{ 7.497545867f, 3.162953546f, 1.f },
	{ 9.00220326f, 3.339047188f, 1.f },
	{ 7.444542326f, 0.476683375f, 1.f },
	{ 10.12493903f, 3.234550982f, 1.f },
	{ 6.642287351f, 3.319983761f, 1.f } };
	const std::vector<float> classes{ 0.f, 1.f };
	ANN::DecisionTree<float> dt;
	dt.init(data, classes);
	dt.set_max_depth(3);
	dt.set_min_size(1);
	dt.train();
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	dt.save_model(model_name);
	ANN::DecisionTree<float> dt2;
	dt2.load_model(model_name);
	const std::vector<std::vector<float>> test{{0.6f, 1.9f, 0.f}, {9.7f, 4.3f, 1.f}};
	for (const auto& row : test) {
		float ret = dt2.predict(row);
		fprintf(stdout, "predict result: %.1f, actual value: %.1f\n", ret, row[2]);
	} */
	// banknote authentication dataset
#ifdef _MSC_VER
	const char* file_name = "E:/GitCode/NN_Test/data/database/BacknoteDataset/data_banknote_authentication.txt";
#else
	const char* file_name = "data/database/BacknoteDataset/data_banknote_authentication.txt";
#endif
	std::vector<std::vector<float>> data;
	int ret = read_txt_file<float>(file_name, data, ',', 1372, 5);
	if (ret != 0) {
		fprintf(stderr, "parse txt file fail: %s\n", file_name);
		return -1;
	}
	const std::vector<float> classes{ 0.f, 1.f };
	ANN::DecisionTree<float> dt;
	dt.init(data, classes);
	dt.set_max_depth(6);
	dt.set_min_size(10);
	dt.train();
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	// fixed: the return value of save_model was silently ignored
	if (dt.save_model(model_name) != 0) {
		fprintf(stderr, "save model fail: %s\n", model_name);
		return -1;
	}
	return 0;
}
int test_decision_tree_predict()
{
	// Load the model written by test_decision_tree_train() and classify a few
	// held-out banknote samples; the last element of each row is the true
	// label. Returns 0 on success, -1 on failure.
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	ANN::DecisionTree<float> dt;
	// fixed: a failed load was previously ignored, so predict() then ran on
	// an empty tree
	if (dt.load_model(model_name) != 0) {
		fprintf(stderr, "load model fail: %s\n", model_name);
		return -1;
	}
	fprintf(stdout, "max_depth: %d, min_size: %d\n", dt.get_max_depth(), dt.get_min_size());

	// fixed: added 'f' suffixes — double-to-float narrowing inside a braced
	// init list is ill-formed (compilers only accept it as an extension)
	const std::vector<std::vector<float>> test {{-2.5526f,-7.3625f,6.9255f,-0.66811f,1.f},
		{-4.5531f,-12.5854f,15.4417f,-1.4983f,1.f},
		{4.0948f,-2.9674f,2.3689f,0.75429f,0.f},
		{-1.0401f,9.3987f,0.85998f,-5.3336f,0.f},
		{1.0637f,3.6957f,-4.1594f,-1.9379f,1.f}};
	for (const auto& row : test) {
		float ret = dt.predict(row);
		fprintf(stdout, "predict result: %.1f, actual value: %.1f\n", ret, row[4]);
	}
	return 0;
}
训练接口执行结果如下:
测试接口执行结果如下:
训练时生成的模型decision_tree.model内容如下:
6,10
0,0,0.3223,-1,-1
0,1,7.6274,-1,-1
0,2,-4.3839,-1,-1
0,0,-0.39816,-1,-1
0,0,-4.2859,-1,-1
0,0,4.2164,-1,0
0,0,1.594,-1,-1
0,2,6.2204,-1,-1
0,1,5.8974,-1,-1
0,0,-5.4901,-1,1
0,0,-1.5768,-1,-1
0,0,0.47368,1,-1
-1,-1,-1,-1,-1
0,2,-2.2718,-1,-1
0,0,2.0421,-1,-1
0,1,7.3273,-1,1
0,1,-4.6062,-1,-1
0,2,3.1143,-1,-1
0,0,0.049175,0,0
0,0,-6.2003,1,1
-1,-1,-1,-1,-1
0,0,-2.7419,0,-1
0,0,-1.5768,0,0
-1,-1,-1,-1,-1
0,0,0.47368,1,1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,1,7.6377,-1,0
0,3,0.097399,-1,-1
0,2,-2.3386,1,-1
0,0,3.6216,-1,-1
0,0,-1.3971,1,1
-1,-1,-1,-1,-1
0,0,-1.6677,1,1
0,0,-1.7781,0,0
0,0,-0.36506,1,1
0,3,1.547,0,1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,0,-2.7419,0,0
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,0,1.0552,1,1
-1,-1,-1,-1,-1
0,0,0.4339,0,0
0,2,2.0013,1,0
-1,-1,-1,-1,-1
0,0,1.8993,0,0
0,0,3.4566,0,0
0,0,3.6216,0,0