决策树的C++实现(CART)

关于决策树的介绍可以参考: https://blog.csdn.net/fengbingchun/article/details/78880934

CART算法的决策树的Python实现可以参考: https://blog.csdn.net/fengbingchun/article/details/78881143

这里参考 https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/ 这篇文章的原有Python实现,使用C++实现了决策树的CART算法,测试数据集是Banknote Dataset,关于Banknote Dataset的介绍可以参考: https://blog.csdn.net/fengbingchun/article/details/78624358 。

decision_tree.hpp文件内容如下:

#ifndef FBC_NN_DECISION_TREE_HPP_
#define FBC_NN_DECISION_TREE_HPP_

#include <vector>
#include <tuple>
#include <fstream>

namespace ANN {
// reference: https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/

template<typename T>
class DecisionTree { // CART(Classification and Regression Trees)
public:
	DecisionTree() = default;
	~DecisionTree() { delete_tree(); }
	// Copies the training data. Each row of 'data' holds the feature values with the
	// class label appended as the last element; 'classes' lists the distinct labels.
	int init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes);
	void set_max_depth(int max_depth) { this->max_depth = max_depth; }
	int get_max_depth() const { return max_depth; }
	void set_min_size(int min_size) { this->min_size = min_size; }
	int get_min_size() const { return min_size; }
	// Builds the tree from the data passed to init() and prints training accuracy.
	void train();
	// Serializes max_depth/min_size plus the tree (dense heap layout) as text.
	int save_model(const char* name) const;
	int load_model(const char* name);
	// Returns the predicted class label for one sample; requires a trained/loaded tree.
	T predict(const std::vector<T>& data) const;

protected:
	typedef std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> dictionary; // index of attribute, value of attribute, groups of data
	typedef std::tuple<int, int, T, T, T> row_element; // flag, index, value, class_value_left, class_value_right
	// Tree node: a leaf side is encoded by a null child pointer plus the
	// corresponding class_value_left/class_value_right terminal label.
	typedef struct binary_tree {
		dictionary dict;
		T class_value_left = (T)-1.f;
		T class_value_right = (T)-1.f;
		binary_tree* left = nullptr;
		binary_tree* right = nullptr;
	} binary_tree;

	// Calculate the Gini index for a split dataset
	T gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const;
	// Select the best split point for a dataset
	dictionary get_split(const std::vector<std::vector<T>>& dataset) const;
	// Split a dataset based on an attribute and an attribute value
	std::vector<std::vector<std::vector<T>>> test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const;
	// Create a terminal node value
	T to_terminal(const std::vector<std::vector<T>>& group) const;
	// Create child splits for a node or make terminal
	void split(binary_tree* node, int depth);
	// Build a decision tree
	void build_tree(const std::vector<std::vector<T>>& train);
	// Print a decision tree
	void print_tree(const binary_tree* node, int depth = 0) const;
	// Make a prediction with a decision tree
	T predict(binary_tree* node, const std::vector<T>& data) const;
	// calculate accuracy percentage
	double accuracy_metric() const;
	void delete_tree();
	void delete_node(binary_tree* node);
	void write_node(const binary_tree* node, std::ofstream& file) const;
	void node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const;
	int height_of_tree(const binary_tree* node) const;
	void row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos);

private:
	std::vector<std::vector<T>> src_data; // training samples: features + label per row
	binary_tree* tree = nullptr; // root of the trained/loaded tree
	int samples_num = 0;
	int feature_length = 0; // number of attributes (columns minus the label)
	int classes_num = 0;
	int max_depth = 10; // maximum tree depth
	int min_size = 10; // minimum node records
	int max_nodes = -1; // (1 << max_depth) - 1: size of the serialized heap layout
};

} // namespace ANN


#endif // FBC_NN_DECISION_TREE_HPP_

decision_tree.cpp文件内容如下:

#include "decision_tree.hpp"

#include <algorithm>
#include <iterator>
#include <set>
#include <sstream>
#include <string>
#include <typeinfo>

#include "common.hpp"

namespace ANN {

template<typename T>
int DecisionTree<T>::init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes)
{
	// Each row of 'data' is: feature values followed by the class label (last element);
	// 'classes' enumerates the distinct class labels.
	CHECK(data.size() != 0 && classes.size() != 0 && data[0].size() != 0);

	this->samples_num = data.size();
	this->classes_num = classes.size();
	this->feature_length = data[0].size() - 1; // last column is the label

	for (const auto& sample : data) {
		this->src_data.emplace_back(sample);
	}

	return 0;
}

template<typename T>
T DecisionTree<T>::gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const
{
	// Weighted Gini impurity of a candidate split:
	//   proportion = count(class_value) / group_size
	//   gini += (1 - sum(proportion^2)) * (group_size / total_samples)
	// Fixes vs original: removed a shadowed, unused local 'p', and the proportions
	// are computed in T instead of float (the float casts lost precision for T=double).

	// count all samples at the split point
	int instances = 0;
	for (const auto& group : groups) {
		instances += static_cast<int>(group.size());
	}

	// sum the weighted Gini index over all groups
	T gini = (T)0.;
	for (const auto& group : groups) {
		const int size = static_cast<int>(group.size());
		if (size == 0) continue; // avoid divide by zero

		// score the group based on the proportion of each class
		T score = (T)0.;
		for (size_t c = 0; c < classes.size(); ++c) {
			int count = 0;
			for (int t = 0; t < size; ++t) {
				if (group[t][this->feature_length] == classes[c]) ++count;
			}
			T p = (T)count / size; // class proportion within this group
			score += p * p;
		}

		// weight the group score by its relative size
		gini += ((T)1. - score) * size / instances;
	}

	return gini;
}

template<typename T>
std::vector<std::vector<std::vector<T>>> DecisionTree<T>::test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const
{
	// Partition the dataset by comparing attribute 'index' against 'value':
	// rows with attribute < value go to groups[0] (left), the rest to groups[1] (right).
	std::vector<std::vector<std::vector<T>>> groups(2);

	for (const auto& sample : dataset) {
		const int side = (sample[index] < value) ? 0 : 1;
		groups[side].emplace_back(sample);
	}

	return groups;
}

template<typename T>
std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> DecisionTree<T>::get_split(const std::vector<std::vector<T>>& dataset) const
{
	// Collect the distinct class labels (last column of every row).
	std::set<T> unique_labels;
	for (const auto& sample : dataset) {
		unique_labels.insert(sample[this->feature_length]);
	}
	std::vector<T> class_values(unique_labels.cbegin(), unique_labels.cend());

	// Exhaustively evaluate every (attribute, value) candidate split and
	// keep the one with the lowest Gini index. 999 acts as "worse than any
	// real score" sentinel, matching the reference implementation.
	int best_index = 999;
	T best_value = (T)999.;
	T best_score = (T)999.;
	std::vector<std::vector<std::vector<T>>> best_groups(2);

	for (int index = 0; index < this->feature_length; ++index) {
		for (size_t row = 0; row < dataset.size(); ++row) {
			auto groups = test_split(index, dataset[row][index], dataset);
			T gini = gini_index(groups, class_values);

			if (gini < best_score) {
				best_index = index;
				best_value = dataset[row][index];
				best_score = gini;
				best_groups = groups;
			}
		}
	}

	// node payload: chosen attribute index, split value, and the two data groups
	return std::make_tuple(best_index, best_value, best_groups);
}

template<typename T>
T DecisionTree<T>::to_terminal(const std::vector<std::vector<T>>& group) const
{
	// Return the most common class label in the group; ties resolve to the
	// smallest label because the unique labels are visited in ascending order
	// and only a strictly larger count replaces the current best.
	std::vector<T> labels;
	for (const auto& sample : group) {
		labels.emplace_back(sample[this->feature_length]);
	}

	std::set<T> unique_labels(labels.cbegin(), labels.cend());
	int max_count = -1;
	T best_label = (T)-1.f;
	for (const auto& label : unique_labels) {
		const int count = std::count(labels.cbegin(), labels.cend(), label);
		if (max_count < count) {
			max_count = count;
			best_label = label;
		}
	}

	return best_label;
}

template<typename T>
void DecisionTree<T>::split(binary_tree* node, int depth)
{
	// Recursively grow children for 'node' (depth-first) or mark it terminal.
	std::vector<std::vector<T>> left = std::get<2>(node->dict)[0];
	std::vector<std::vector<T>> right = std::get<2>(node->dict)[1];
	std::get<2>(node->dict).clear(); // the groups are no longer needed once extracted

	// no split happened (one side empty): both children share a single terminal value
	if (left.empty() || right.empty()) {
		left.insert(left.end(), right.begin(), right.end());
		node->class_value_left = node->class_value_right = to_terminal(left);
		return;
	}

	// maximum depth reached: force both children terminal
	if (depth >= max_depth) {
		node->class_value_left = to_terminal(left);
		node->class_value_right = to_terminal(right);
		return;
	}

	// left child: terminal when too few records, otherwise split again
	if (left.size() <= (size_t)min_size) {
		node->class_value_left = to_terminal(left);
	} else {
		node->left = new binary_tree;
		node->left->dict = get_split(left);
		split(node->left, depth + 1);
	}

	// right child: terminal when too few records, otherwise split again
	if (right.size() <= (size_t)min_size) {
		node->class_value_right = to_terminal(right);
	} else {
		node->right = new binary_tree;
		node->right->dict = get_split(right);
		split(node->right, depth + 1);
	}
}

template<typename T>
void DecisionTree<T>::build_tree(const std::vector<std::vector<T>>& train)
{
	// Root = best split of the full training set; then grow recursively from depth 1.
	binary_tree* root = new binary_tree;
	root->dict = get_split(train);
	tree = root;
	split(root, 1);
}

template<typename T>
void DecisionTree<T>::train()
{
	// Node count of a full binary tree of height max_depth; used by save_model
	// when flattening the tree into a fixed-size heap-layout array.
	this->max_nodes = (1 << max_depth) - 1;
	build_tree(src_data);

	accuracy_metric(); // report accuracy on the training data itself
	
	//binary_tree* tmp = tree;
	//print_tree(tmp);
}

template<typename T>
T DecisionTree<T>::predict(const std::vector<T>& data) const
{
	// Guard against predicting before train()/load_model(); -1111 is the error sentinel.
	if (tree == nullptr) {
		fprintf(stderr, "Error, tree is null\n");
		return -1111.f;
	}

	return predict(tree, data);
}

template<typename T>
T DecisionTree<T>::predict(binary_tree* node, const std::vector<T>& data) const
{
	// Walk down the tree iteratively: go left when the sample's attribute is
	// below the node's split value; a null child means that side is a leaf.
	while (true) {
		const bool go_left = data[std::get<0>(node->dict)] < std::get<1>(node->dict);
		binary_tree* child = go_left ? node->left : node->right;
		if (child == nullptr) {
			return go_left ? node->class_value_left : node->class_value_right;
		}
		node = child;
	}
}

template<typename T>
int DecisionTree<T>::save_model(const char* name) const
{
	// Serialize max_depth/min_size on the first line, then the tree as a dense
	// heap-layout array of (1<<max_depth)-1 rows, one node per line (see write_node).
	// Returns 0 on success, -1 when the file cannot be opened.
	std::ofstream file(name, std::ios::out);
	if (!file.is_open()) {
		fprintf(stderr, "open file fail: %s\n", name);
		return -1;
	}

	file << max_depth << "," << min_size << std::endl;

	// Fix: the tree may legitimately finish shallower than max_depth (pure nodes
	// or min_size pruning), so only require that it does not exceed the configured
	// depth. The original CHECK(max_depth == depth) aborted on shallow trees.
	int depth = height_of_tree(tree);
	CHECK(depth <= max_depth);

	write_node(tree, file);

	file.close();
	return 0;
}

template<typename T>
void DecisionTree<T>::write_node(const binary_tree* node, std::ofstream& file) const
{
	// Flatten the tree into a fixed-size heap-layout array of row_element
	// (flag, index, value, class_value_left, class_value_right); flag -1 marks an
	// absent node. One comma-separated row per line. (Dead commented-out in-order
	// traversal removed from the original.)
	std::vector<row_element> vec(this->max_nodes, std::make_tuple(-1, -1, (T)-1.f, (T)-1.f, (T)-1.f));

	// node_to_row_element takes a non-const pointer; the traversal only reads the tree
	node_to_row_element(const_cast<binary_tree*>(node), vec, 0);

	for (const auto& row : vec) {
		file << std::get<0>(row) << "," << std::get<1>(row) << "," << std::get<2>(row) << ","
			<< std::get<3>(row) << "," << std::get<4>(row) << std::endl;
	}
}

template<typename T>
void DecisionTree<T>::node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const
{
	// Store 'node' at heap position 'pos' (children at 2*pos+1 / 2*pos+2).
	// Flag 0 marks a real node; the default -1 marks an empty slot.
	if (node == nullptr) return;

	rows[pos] = std::make_tuple(0, std::get<0>(node->dict), std::get<1>(node->dict),
		node->class_value_left, node->class_value_right);

	node_to_row_element(node->left, rows, 2 * pos + 1);
	node_to_row_element(node->right, rows, 2 * pos + 2);
}

template<typename T>
int DecisionTree<T>::height_of_tree(const binary_tree* node) const
{
	// Height counted in nodes along the longest root-to-leaf path; empty tree -> 0.
	if (node == nullptr) return 0;
	const int left_height = height_of_tree(node->left);
	const int right_height = height_of_tree(node->right);
	return 1 + std::max(left_height, right_height);
}

template<typename T>
int DecisionTree<T>::load_model(const char* name)
{
	// Parse the file written by save_model: first line "max_depth,min_size",
	// then (1<<max_depth)-1 lines of "flag,index,value,class_value_left,class_value_right"
	// in dense heap layout. Returns 0 on success, -1 when the file cannot be opened.
	std::ifstream file(name, std::ios::in);
	if (!file.is_open()) {
		fprintf(stderr, "open file fail: %s\n", name);
		return -1;
	}

	std::string line, cell;
	std::getline(file, line);
	std::stringstream line_stream(line);
	std::vector<int> vec;
	int count = 0;
	while (std::getline(line_stream, cell, ',')) {
		vec.emplace_back(std::stoi(cell));
	}
	CHECK(vec.size() == 2);
	max_depth = vec[0];
	min_size = vec[1];
	max_nodes = (1 << max_depth) - 1;
	std::vector<row_element> rows(max_nodes);

	// One parsing loop for both instantiations (std::stod covers float and double;
	// narrowing to float matches the precision save_model wrote). This also fixes
	// the original double branch, which read the fifth field from the wrong vector
	// ('vec[4]' -- an out-of-range access on the 2-element header vector).
	while (std::getline(file, line)) {
		std::stringstream line_stream2(line);
		std::vector<T> vec2;

		while (std::getline(line_stream2, cell, ',')) {
			vec2.emplace_back(static_cast<T>(std::stod(cell)));
		}

		CHECK(vec2.size() == 5);
		rows[count] = std::make_tuple((int)vec2[0], (int)vec2[1], vec2[2], vec2[3], vec2[4]);
		++count;
	}

	CHECK(max_nodes == count);
	CHECK(std::get<0>(rows[0]) != -1); // the root must exist

	// release any previously trained/loaded tree so repeated loads do not leak
	if (tree != nullptr) {
		delete_node(tree);
		tree = nullptr;
	}

	binary_tree* root = new binary_tree;
	std::vector<std::vector<std::vector<T>>> dump; // the data groups are not persisted
	root->dict = std::make_tuple(std::get<1>(rows[0]), std::get<2>(rows[0]), dump);
	root->class_value_left = std::get<3>(rows[0]);
	root->class_value_right = std::get<4>(rows[0]);
	tree = root;
	row_element_to_node(root, rows, max_nodes, 0);

	file.close();
	return 0;
}

template<typename T>
void DecisionTree<T>::row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos)
{
	// Rebuild the children of 'node' from the heap-layout rows: left child at
	// 2*pos+1, right child at 2*pos+2; a flag of -1 marks an absent child.
	if (node == nullptr || n == 0) return;

	for (int side = 0; side < 2; ++side) { // 0: left, 1: right
		const int child_pos = 2 * pos + 1 + side;
		if (child_pos >= n || std::get<0>(rows[child_pos]) == -1) continue;

		binary_tree* child = new binary_tree;
		child->dict = std::make_tuple(std::get<1>(rows[child_pos]), std::get<2>(rows[child_pos]),
			std::vector<std::vector<std::vector<T>>>());
		child->class_value_left = std::get<3>(rows[child_pos]);
		child->class_value_right = std::get<4>(rows[child_pos]);

		if (side == 0) node->left = child;
		else node->right = child;

		row_element_to_node(child, rows, n, child_pos);
	}
}

template<typename T>
void DecisionTree<T>::delete_tree()
{
	// Fix: guard against a never-built tree -- the original unconditionally called
	// delete_node(tree), which dereferenced nullptr when the object was destroyed
	// without train()/load_model(). Also reset the pointer so it never dangles.
	if (tree != nullptr) {
		delete_node(tree);
		tree = nullptr;
	}
}

template<typename T>
void DecisionTree<T>::delete_node(binary_tree* node)
{
	// Post-order deletion of the subtree rooted at 'node'.
	// Fix: tolerate a null argument -- the original read node->left without a
	// null check and crashed when called with nullptr (e.g. an untrained tree).
	if (node == nullptr) return;
	delete_node(node->left);
	delete_node(node->right);
	delete node;
}

template<typename T>
double DecisionTree<T>::accuracy_metric() const
{
	// Percentage of training samples whose prediction matches the stored label
	// (the label is the last element of each row).
	int correct = 0;
	for (int i = 0; i < this->samples_num; ++i) {
		const auto& sample = src_data[i];
		if (predict(tree, sample) == sample[this->feature_length]) ++correct;
	}

	double accuracy = correct / (double)samples_num * 100.;
	fprintf(stdout, "train accuracy: %f\n", accuracy);

	return accuracy;
}

template<typename T>
void DecisionTree<T>::print_tree(const binary_tree* node, int depth) const
{
	// Pretty-print the tree, one node per line, indented proportionally to depth.
	// Fix: the original did 'blank += blank' in a loop, doubling the string each
	// level (2^depth spaces) instead of growing it linearly like the reference
	// Python implementation (depth * ' ').
	if (!node) return;

	std::string blank(depth + 1, ' ');
	fprintf(stdout, "%s[X%d < %.3f]\n", blank.c_str(), std::get<0>(node->dict) + 1, std::get<1>(node->dict));

	std::string leaf_blank = blank + " "; // terminals print one level deeper

	if (node->left)
		print_tree(node->left, depth + 1);
	else
		fprintf(stdout, "%s[%.1f]\n", leaf_blank.c_str(), node->class_value_left);

	if (node->right)
		print_tree(node->right, depth + 1);
	else
		fprintf(stdout, "%s[%.1f]\n", leaf_blank.c_str(), node->class_value_right);
}

template class DecisionTree<float>;
template class DecisionTree<double>;

} // namespace ANN

对外提供两个接口,一个是test_decision_tree_train用于训练,一个是test_decision_tree_predict用于测试,其code如下:

// =============================== decision tree ==============================
// Trains a CART decision tree on the banknote authentication dataset and
// writes the resulting model to disk (consumed by test_decision_tree_predict).
// Returns 0 on success, -1 when the dataset file cannot be parsed.
int test_decision_tree_train()
{
	// small dataset test
	/*const std::vector<std::vector<float>> data{ { 2.771244718f, 1.784783929f, 0.f },
					{ 1.728571309f, 1.169761413f, 0.f },
					{ 3.678319846f, 2.81281357f, 0.f },
					{ 3.961043357f, 2.61995032f, 0.f },
					{ 2.999208922f, 2.209014212f, 0.f },
					{ 7.497545867f, 3.162953546f, 1.f },
					{ 9.00220326f, 3.339047188f, 1.f },
					{ 7.444542326f, 0.476683375f, 1.f },
					{ 10.12493903f, 3.234550982f, 1.f },
					{ 6.642287351f, 3.319983761f, 1.f } };

	const std::vector<float> classes{ 0.f, 1.f };

	ANN::DecisionTree<float> dt;
	dt.init(data, classes);
	dt.set_max_depth(3);
	dt.set_min_size(1);

	dt.train();
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	dt.save_model(model_name);

	ANN::DecisionTree<float> dt2;
	dt2.load_model(model_name);
	const std::vector<std::vector<float>> test{{0.6f, 1.9f, 0.f}, {9.7f, 4.3f, 1.f}};
	for (const auto& row : test) {
		float ret = dt2.predict(row);
		fprintf(stdout, "predict result: %.1f, actual value: %.1f\n", ret, row[2]);
	} */

	// banknote authentication dataset
#ifdef _MSC_VER
	const char* file_name = "E:/GitCode/NN_Test/data/database/BacknoteDataset/data_banknote_authentication.txt";
#else
	const char* file_name = "data/database/BacknoteDataset/data_banknote_authentication.txt";
#endif

	// load the CSV: 1372 samples, 4 features plus the class label per row
	std::vector<std::vector<float>> data;
	int ret = read_txt_file<float>(file_name, data, ',', 1372, 5);
	if (ret != 0) {
		fprintf(stderr, "parse txt file fail: %s\n", file_name);
		return -1;
	}

	//fprintf(stdout, "data size: rows: %d\n", data.size());

	// binary classification: 0 = authentic, 1 = forged (per the dataset description)
	const std::vector<float> classes{ 0.f, 1.f };
	ANN::DecisionTree<float> dt;
	dt.init(data, classes);
	dt.set_max_depth(6);
	dt.set_min_size(10);
	dt.train();
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	dt.save_model(model_name);

	return 0;
}

// Restores the tree saved by test_decision_tree_train and classifies a few samples.
int test_decision_tree_predict()
{
#ifdef _MSC_VER
	const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
	const char* model_name = "data/decision_tree.model";
#endif
	ANN::DecisionTree<float> dt;
	dt.load_model(model_name);
	int max_depth = dt.get_max_depth();
	int min_size = dt.get_min_size();
	fprintf(stdout, "max_depth: %d, min_size: %d\n", max_depth, min_size);

	// each row: four banknote features followed by the expected class label
	std::vector<std::vector<float>> test {{-2.5526,-7.3625,6.9255,-0.66811,1},
				       {-4.5531,-12.5854,15.4417,-1.4983,1},
				       {4.0948,-2.9674,2.3689,0.75429,0},
				       {-1.0401,9.3987,0.85998,-5.3336,0},
				       {1.0637,3.6957,-4.1594,-1.9379,1}};
	for (const auto& row : test) {
		const float ret = dt.predict(row);
		fprintf(stdout, "predict result: %.1f, actual value: %.1f\n", ret, row[4]);
	}

	return 0;
}

训练接口执行结果如下:

测试接口执行结果如下:

训练时生成的模型decision_tree.model内容如下:

6,10
0,0,0.3223,-1,-1
0,1,7.6274,-1,-1
0,2,-4.3839,-1,-1
0,0,-0.39816,-1,-1
0,0,-4.2859,-1,-1
0,0,4.2164,-1,0
0,0,1.594,-1,-1
0,2,6.2204,-1,-1
0,1,5.8974,-1,-1
0,0,-5.4901,-1,1
0,0,-1.5768,-1,-1
0,0,0.47368,1,-1
-1,-1,-1,-1,-1
0,2,-2.2718,-1,-1
0,0,2.0421,-1,-1
0,1,7.3273,-1,1
0,1,-4.6062,-1,-1
0,2,3.1143,-1,-1
0,0,0.049175,0,0
0,0,-6.2003,1,1
-1,-1,-1,-1,-1
0,0,-2.7419,0,-1
0,0,-1.5768,0,0
-1,-1,-1,-1,-1
0,0,0.47368,1,1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,1,7.6377,-1,0
0,3,0.097399,-1,-1
0,2,-2.3386,1,-1
0,0,3.6216,-1,-1
0,0,-1.3971,1,1
-1,-1,-1,-1,-1
0,0,-1.6677,1,1
0,0,-1.7781,0,0
0,0,-0.36506,1,1
0,3,1.547,0,1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,0,-2.7419,0,0
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,0,1.0552,1,1
-1,-1,-1,-1,-1
0,0,0.4339,0,0
0,2,2.0013,1,0
-1,-1,-1,-1,-1
0,0,1.8993,0,0
0,0,3.4566,0,0
0,0,3.6216,0,0

GitHub: https://github.com/fengbingchun/NN_Test 

  • 5
    点赞
  • 49
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值