#include <iostream>
#include <vector>
#include <queue>
#include <math.h>
#include <fstream>
using namespace std;
// Node of the binary decision tree produced by build().
//
// Internal node: a sample x goes to `left` when x[feature_id] < threshold, otherwise
// to `right`.
// Leaf node (as created by build()): `left` and `right` are null, `feature_id` holds the
// predicted label and `threshold` is the marker value -99.
//
// Members are default-initialized so a freshly created node never holds garbage.
struct node {
    node *left = nullptr;
    node *right = nullptr;
    int feature_id = 0;
    float threshold = 0;
};
// Binary Shannon entropy (in bits) of the label column of `data`.
//
// The label is the last column of each row. Rows whose label equals 1 form one class;
// every other row forms the other class.
// Returns 0 for an empty set, for rows with no columns, and for a pure set (all rows in
// one class).
float ent(vector<vector<float>> &data) {
    // An empty partition has no impurity; guarding here also avoids reading data[0].
    if (data.empty() || data[0].empty()) return 0;
    const size_t label = data[0].size() - 1;
    size_t one = 0;
    for (size_t j = 0; j < data.size(); ++j) {
        if (data[j][label] == 1) ++one;
    }
    // Pure set: entropy is 0 (and log2(0) below would produce NaN).
    if (one == 0 || one == data.size()) return 0;
    float tmp1 = 1.0 * one / data.size();
    float tmp0 = 1 - tmp1;
    return -tmp0 * log2(tmp0) - tmp1 * log2(tmp1);
}
// Weighted entropy of the two partitions produced by splitting `data` on one feature.
//
// Rows with data[row][feature_id] < threshold form the left side, all others the right
// side. Each side's entropy (from ent()) is weighted by that side's share of the rows.
// Lower values mean a purer split.
// Returns 0 for empty `data`.
float cal_ent(vector<vector<float>> &data, int &feature_id, float &threshold) {
    // Nothing to split; this also avoids dividing by data.size() == 0 below.
    if (data.empty()) return 0;
    vector<vector<float>> data1, data2;
    for (size_t j = 0; j < data.size(); ++j) {
        if (data[j][feature_id] < threshold) data1.push_back(data[j]);
        else data2.push_back(data[j]);
    }
    // An empty side contributes nothing, so ent() is only consulted for non-empty sides.
    float tmp = 0;
    if (!data2.empty()) tmp += 1.0 * data2.size() / data.size() * ent(data2);
    if (!data1.empty()) tmp += 1.0 * data1.size() / data.size() * ent(data1);
    return tmp;
}
// Ratio-style split score: weighted entropy of the two partitions divided by the split
// information (the entropy of the left/right size proportions).
//
// Rows with data[row][feature_id] < threshold form the left side, all others the right
// side.
// Returns the sentinel 99 when one side is empty (including empty `data`). In that case
// the split information is 0, so the ratio is undefined.
//
// NOTE(review): the numerator is the conditional entropy, not the information gain used
// by the textbook C4.5 gain ratio, so lower is better here. Confirm this is the intended
// criterion before switching split() to it.
float cal_ent2(vector<vector<float>> &data, int &feature_id, float &threshold) {
    vector<vector<float>> data1, data2;
    for (size_t j = 0; j < data.size(); ++j) {
        if (data[j][feature_id] < threshold) data1.push_back(data[j]);
        else data2.push_back(data[j]);
    }
    // A split that sends every row to one side has zero split information. Without this
    // check, log2(0) would yield NaN and ent() would be called on an empty set.
    if (data1.empty() || data2.empty()) return 99;
    float tmp = 1.0 * data2.size() / data.size();
    float iv = -tmp * log2(tmp) - (1 - tmp) * log2(1 - tmp);
    float gg = tmp * ent(data2) + (1 - tmp) * ent(data1);
    return gg / iv;
}
// Binary Gini impurity of the label column of `data`.
//
// The label is the last column of each row. Rows whose label equals 1 form one class;
// every other row forms the other class.
// Returns 1 - p0^2 - p1^2, which is 0 for an empty set, for rows with no columns, and for
// a pure set.
float gini(vector<vector<float>> &data) {
    // An empty partition has no impurity; guarding here also avoids reading data[0].
    if (data.empty() || data[0].empty()) return 0;
    const size_t label = data[0].size() - 1;
    size_t one = 0;
    for (size_t j = 0; j < data.size(); ++j) {
        if (data[j][label] == 1) ++one;
    }
    if (one == 0 || one == data.size()) return 0;
    float tmp1 = 1.0 * one / data.size();
    float tmp0 = 1 - tmp1;
    return 1 - tmp0 * tmp0 - tmp1 * tmp1;
}
// Weighted Gini impurity of the two partitions produced by splitting `data` on one
// feature.
//
// Rows with data[row][feature_id] < threshold form the left side, all others the right
// side. Each side's impurity (from gini()) is weighted by that side's share of the rows.
// Lower values mean a purer split.
// Returns 0 for empty `data`.
float cal_gini(vector<vector<float>> &data, int &feature_id, float &threshold) {
    // Nothing to split; this also avoids dividing by data.size() == 0 below.
    if (data.empty()) return 0;
    vector<vector<float>> data1, data2;
    for (size_t j = 0; j < data.size(); ++j) {
        if (data[j][feature_id] < threshold) data1.push_back(data[j]);
        else data2.push_back(data[j]);
    }
    // An empty side contributes nothing, so gini() is only consulted for non-empty sides.
    float tmp = 0;
    if (!data2.empty()) tmp += 1.0 * data2.size() / data.size() * gini(data2);
    if (!data1.empty()) tmp += 1.0 * data1.size() / data.size() * gini(data1);
    return tmp;
}
// Chooses the axis-aligned split of `data` with the lowest weighted Gini impurity (as
// computed by cal_gini) and partitions the rows accordingly.
//
// Candidate thresholds are every observed value of every feature column. All columns
// except the last are features; the last holds the label. Rows with
// data[row][feature_id] < threshold are appended to data1, all other rows to data2.
// Ties keep the first candidate found (feature-major, then row order).
//
// A candidate equal to its feature's minimum value is skipped. It would send every row to
// data2, so the "split" would not partition anything and a recursive caller would
// recurse on the same rows forever.
// If no feature takes at least two distinct values, no real split exists. Then
// feature_id = 0, threshold = data[0][0] (when a feature column exists), and every row
// is appended to data2, so the caller can detect the degenerate result via an empty
// data1.
//
// data:         training rows (feature values followed by the label).
// data1, data2: output partitions; rows are appended, existing contents are kept.
// feature_id, threshold: outputs describing the chosen split (0/0 for empty data).
void split(vector<vector<float>> &data, vector<vector<float>> &data1,
           vector<vector<float>> &data2, int &feature_id, float &threshold) {
    feature_id = 0;
    threshold = 0;
    if (data.empty()) return;
    const int feature_size = static_cast<int>(data[0].size()) - 1;
    float best = 99;  // sentinel well above any weighted impurity
    bool found = false;
    for (int i = 0; i < feature_size; ++i) {
        for (size_t j = 0; j < data.size(); ++j) {
            float &candidate = data[j][i];
            // Only thresholds that put at least one row on the left side partition data.
            bool has_left = false;
            for (size_t k = 0; k < data.size(); ++k) {
                if (data[k][i] < candidate) {
                    has_left = true;
                    break;
                }
            }
            if (!has_left) continue;
            const float impurity = cal_gini(data, i, candidate);
            if (impurity < best) {
                best = impurity;
                feature_id = i;
                threshold = candidate;
                found = true;
            }
        }
    }
    if (!found) {
        // Every feature is constant across the rows: report the degenerate split.
        if (feature_size > 0) threshold = data[0][0];
        data2.insert(data2.end(), data.begin(), data.end());
        return;
    }
    for (size_t j = 0; j < data.size(); ++j) {
        if (data[j][feature_id] < threshold) data1.push_back(data[j]);
        else data2.push_back(data[j]);
    }
}
// Creates a leaf predicting `value`. The value is stored, truncated to int, in
// feature_id; leaves have no children and carry the marker threshold -99.
static node *make_leaf(float value) {
    node *leaf = new node;
    leaf->feature_id = static_cast<int>(value);
    leaf->threshold = -99;
    leaf->left = nullptr;
    leaf->right = nullptr;
    return leaf;
}

// Returns the most frequent value in column `label` of `data`; on a tie, the value
// encountered first wins. `data` must be non-empty.
static float majority_label(const vector<vector<float>> &data, size_t label) {
    float best = data[0][label];
    size_t best_count = 0;
    for (size_t i = 0; i < data.size(); ++i) {
        size_t count = 0;
        for (size_t k = 0; k < data.size(); ++k) {
            if (data[k][label] == data[i][label]) ++count;
        }
        if (count > best_count) {
            best_count = count;
            best = data[i][label];
        }
    }
    return best;
}

// Recursively builds a binary decision tree from `data`. Each row holds the feature
// values followed by the label in the last column; rows must be non-empty.
//
// Returns null for empty data. If every row has the same label, returns a leaf for that
// label. Otherwise the rows are partitioned with split() and both sides are built
// recursively. If split() cannot separate the rows (one side empty, e.g. rows identical
// in every feature but with different labels), a leaf with the majority label is
// returned instead of recursing on the same rows forever.
// The caller owns the returned nodes (allocated with new).
node *build(vector<vector<float>> data) {
    if (data.empty()) return nullptr;
    const size_t label = data[0].size() - 1;
    bool pure = true;
    for (size_t i = 1; i < data.size(); ++i) {
        if (data[i][label] != data[0][label]) {
            pure = false;
            break;
        }
    }
    if (pure) return make_leaf(data[0][label]);

    vector<vector<float>> data1, data2;
    int feature_id = 0;
    float threshold = 0;
    split(data, data1, data2, feature_id, threshold);
    if (data1.empty() || data2.empty()) return make_leaf(majority_label(data, label));

    node *father = new node;
    father->left = build(data1);
    father->right = build(data2);
    father->feature_id = feature_id;
    father->threshold = threshold;
    return father;
}
void print_tree(node *tree) {
if(tree == NULL) return;
queue<node*> q;
queue<int> qq;
int cur = 1;
q.push(tree);
qq.push(1);
while(!q.empty()) {
node *tmp = q.front();
q.pop();
int tp = qq.front();
qq.pop();
if(tp > cur) {
cout << endl;
cur = tp;
}
if(pow(2, tp-1) > 10) break;
if(tmp == NULL) {
cout << "bb bb\t";
q.push(NULL);
q.push(NULL);
qq.push(tp+1);
qq.push(tp+1);
continue;
}
cout << tmp->feature_id << " " << tmp->threshold << "\t";
q.push(tmp->left);
q.push(tmp->right);
qq.push(tp+1);
qq.push(tp+1);
}
}
int predict(node *tree, vector<float> x) {
if(x.size() != 4) return -1;
node *tmp = tree;
while(tmp->threshold != -99) {
if(x[tmp->feature_id] < tmp->threshold) tmp = tmp->left;
else tmp = tmp->right;
}
return tmp->feature_id;
}
// Trains a decision tree on ../dt.txt, prints its top levels and classifies one
// hard-coded sample.
//
// Each row of ../dt.txt holds 4 feature values followed by a label (5 whitespace-separated
// numbers). All complete rows are read. Reading stops at end of file or at the first value
// that fails to parse, and a trailing incomplete row is reported and ignored.
// Returns 0 on success, or 1 if the file cannot be opened or contains no complete row.
int main() {
    const int kColumns = 5;  // 4 features + 1 label, matching the 4-feature query below
    ifstream input("../dt.txt");
    if (!input) {
        cerr << "error: cannot open ../dt.txt" << endl;
        return 1;
    }
    vector<vector<float>> data;
    vector<float> row(kColumns);
    for (;;) {
        int read = 0;
        while (read < kColumns && input >> row[read]) ++read;
        if (read < kColumns) {
            if (read > 0) cerr << "warning: ignoring incomplete or malformed row in ../dt.txt" << endl;
            break;
        }
        data.push_back(row);
    }
    if (data.empty()) {
        cerr << "error: no complete rows in ../dt.txt" << endl;
        return 1;
    }
    // The ifstream closes itself when it goes out of scope.

    node *tree = build(data);
    print_tree(tree);
    vector<float> x(4);
    x[0] = 69;
    x[1] = 100;
    x[2] = 7;
    x[3] = 100;
    cout << predict(tree, x) << endl;
    return 0;
}
99 80 5 90 1
89 100 6 100 1
50 60 8 70 0
95 70 9 80 0
98 60 10 80 1
92 65 11 100 1
91 80 12 85 1
85 80 13 95 1
85 91 14 98 1
60 100 7 100 0