决策树的 C++ 实现：ID3、C4.5、CART

#include <math.h>

#include <cstddef>
#include <fstream>
#include <iostream>
#include <queue>
#include <vector>

using namespace std;

// Binary decision-tree node.
// Internal node: a sample x goes to `left` when x[feature_id] < threshold,
// otherwise to `right`.
// Leaf (as created by build()): threshold == -99, feature_id stores the
// predicted label, and both children are null.
// Members are default-initialized so a freshly `new`ed node never holds
// garbage child pointers.
struct node {
    node *left = nullptr;
    node *right = nullptr;
    int feature_id = 0;
    float threshold = 0;
};
/************************************  ID3  ****************************************/
// Binary Shannon entropy (in bits, via log2) of the label column, i.e. the
// last column of each row of `data`.
// A row counts as positive when its label == 1; any other label value is
// treated as negative, so this assumes 0/1 labels.
// Returns 0 for an empty set or a pure set (all positive or all negative).
float ent(const std::vector<std::vector<float>> &data) {
    if (data.empty()) return 0;  // data[0] below would be undefined behaviour
    const std::size_t label = data[0].size() - 1;
    std::size_t one = 0;
    for (const auto &row : data) {
        if (row[label] == 1) ++one;
    }
    if (one == 0 || one == data.size()) return 0;
    const float p1 = 1.0 * one / data.size();
    const float p0 = 1 - p1;
    return -p0 * log2(p0) - p1 * log2(p1);
}

float cal_ent(vector<vector<float>> &data, int &feature_id, float &threshold) {
    vector<vector<float>> data1, data2;
    for(int j = 0; j < data.size(); ++j) {
        if(data[j][feature_id] < threshold) data1.push_back(data[j]);
        else data2.push_back(data[j]);
    }
    float tmp = 1.0*data2.size()/data.size()*ent(data2);
    if(data1.size() > 0) tmp += 1.0*data1.size()/data.size()*ent(data1);
    return tmp;
}
/************************************  C4.5  ****************************************/
float cal_ent2(vector<vector<float>> &data, int &feature_id, float &threshold) {
    vector<vector<float>> data1, data2;
    for(int j = 0; j < data.size(); ++j) {
        if(data[j][feature_id] < threshold) data1.push_back(data[j]);
        else data2.push_back(data[j]);
    }
    if(data2.size() == data.size()) return 99;
    float tmp = 1.0*data2.size()/data.size();
    float iv = -tmp*log2(tmp) - (1-tmp)*log2(1-tmp);
    float gg = tmp * ent(data2) + (1-tmp)*ent(data1);
    return gg/iv;
}
/**************************************  CART  *******************************************/
// Gini impurity of the label column, i.e. the last column of each row of
// `data`: 1 - p0^2 - p1^2.
// A row counts as positive when its label == 1; any other label value is
// treated as negative, so this assumes 0/1 labels.
// Returns 0 for an empty set or a pure set.
float gini(const std::vector<std::vector<float>> &data) {
    if (data.empty()) return 0;  // data[0] below would be undefined behaviour
    const std::size_t label = data[0].size() - 1;
    std::size_t one = 0;
    for (const auto &row : data) {
        if (row[label] == 1) ++one;
    }
    if (one == 0 || one == data.size()) return 0;
    const float p1 = 1.0 * one / data.size();
    const float p0 = 1 - p1;
    return 1 - p0 * p0 - p1 * p1;
}

float cal_gini(vector<vector<float>> &data, int &feature_id, float &threshold) {
    vector<vector<float>> data1, data2;
    for(int j = 0; j < data.size(); ++j) {
        if(data[j][feature_id] < threshold) data1.push_back(data[j]);
        else data2.push_back(data[j]);
    }
    float tmp = 1.0*data2.size()/data.size()*gini(data2);
    if(data1.size() > 0) tmp += 1.0*data1.size()/data.size()*gini(data1);
    return tmp;
}
/*********************************************************************************************************************/
void split(vector<vector<float>> &data, vector<vector<float>> &data1, vector<vector<float>> &data2, int &feature_id, float &threshold) {
    int feature_size = data[0].size() - 1;
    float final = 99;
    for(int i = 0; i < feature_size; ++i) {
        for(int j = 0; j < data.size(); ++j) {
            float entropy = cal_gini(data, i, data[j][i]);
            // cout << j << " " << i << " " << entropy << endl;
            if(entropy < final) {
                final = entropy;
                feature_id = i;
                threshold = data[j][i];
            }
        }
    }
    for(int j = 0; j < data.size(); ++j) {
        if(data[j][feature_id] < threshold) data1.push_back(data[j]);
        else data2.push_back(data[j]);
    } 
}

// Allocates a leaf node predicting `label_value`. Leaf encoding:
// threshold == -99, feature_id holds the label (converted to int).
static node *make_leaf(float label_value) {
    node *leaf = new node;
    leaf->feature_id = static_cast<int>(label_value);
    leaf->threshold = -99;
    leaf->left = nullptr;
    leaf->right = nullptr;
    return leaf;
}

// Most frequent value of the label column (last column) of non-empty `data`;
// ties go to the value encountered first.
static float majority_label(const std::vector<std::vector<float>> &data) {
    const std::size_t label = data[0].size() - 1;
    float best = data[0][label];
    std::size_t best_count = 0;
    for (const auto &candidate : data) {
        std::size_t count = 0;
        for (const auto &row : data) {
            if (row[label] == candidate[label]) ++count;
        }
        if (count > best_count) {
            best_count = count;
            best = candidate[label];
        }
    }
    return best;
}

// Recursively grows a binary decision tree over `data`. Each row holds the
// feature values followed by the label in the last column.
// Returns the root, or NULL for empty input. Nodes are allocated with `new`
// and the caller owns the whole tree.
// A pure set becomes a leaf. So does a set that split() cannot separate
// (e.g. identical feature vectors with different labels); that leaf predicts
// the majority label. Without this fallback the recursion would never
// terminate.
node *build(std::vector<std::vector<float>> data) {
    if (data.empty()) return nullptr;
    const std::size_t label = data[0].size() - 1;
    bool pure = true;
    for (std::size_t i = 1; i < data.size(); ++i) {
        if (data[i][label] != data[0][label]) {
            pure = false;
            break;
        }
    }
    if (pure) return make_leaf(data[0][label]);

    // data is non-empty and contains different labels.
    std::vector<std::vector<float>> data1, data2;
    int feature_id = 0;
    float threshold = 0;
    split(data, data1, data2, feature_id, threshold);
    if (data1.empty() || data2.empty()) return make_leaf(majority_label(data));

    node *father = new node;
    father->left = build(data1);
    father->right = build(data2);
    father->feature_id = feature_id;
    father->threshold = threshold;
    return father;
}

void print_tree(node *tree) {
    if(tree == NULL) return;
    queue<node*> q;
    queue<int> qq;
    int cur = 1;
    q.push(tree);
    qq.push(1);
    while(!q.empty()) {
        node *tmp = q.front();
        q.pop();
        int tp = qq.front();
        qq.pop();
        if(tp > cur) {
            cout << endl;
            cur = tp;
        }
        if(pow(2, tp-1) > 10) break;
        if(tmp == NULL) {
            cout << "bb bb\t";
            q.push(NULL);
            q.push(NULL);
            qq.push(tp+1);
            qq.push(tp+1);
            continue;
        }
        cout << tmp->feature_id << " " << tmp->threshold << "\t";
        q.push(tmp->left);
        q.push(tmp->right);
        qq.push(tp+1);
        qq.push(tp+1);
    }
}

int predict(node *tree, vector<float> x) {
    if(x.size() != 4) return -1;
    node *tmp = tree;
    while(tmp->threshold != -99) {
        if(x[tmp->feature_id] < tmp->threshold) tmp = tmp->left;
        else tmp = tmp->right;
    }
    return tmp->feature_id;
}

int main() {
    vector<vector<float>> data;
    ifstream input("../dt.txt");
    for(int i = 0; i < 10; ++i) {
        vector<float> ele(5);
        for(int j = 0; j < 5; ++j) {
            input >> ele[j];
        }
        data.push_back(ele);
    }
    input.close();
    node *tree = build(data);
    print_tree(tree);
    vector<float> x(4);
    x[0] = 69;
    x[1] = 100;
    x[2] = 7;
    x[3] = 100;
    cout << predict(tree, x) << endl;
    return 0;
}
  • dt.txt 的内容如下（每行 4 个特征值和 1 个 0/1 标签，以空白分隔）：
99	80	5	90	1
89	100	6	100	1
50	60	8	70	0
95	70	9	80	0
98	60	10	80	1
92	65	11	100	1
91	80	12	85	1
85	80	13	95	1
85	91	14	98	1
60	100	7	100	0
  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

刀么克瑟拉莫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值