决策树 CART法

#include <stdio.h>
#include <string.h>
#include <vector>
#include <set>
using namespace std;
const int INF=20,TPN=2;
const double Tgini=0.001;


vector<int> aim(INF);
double gini[INF][INF];//第Xi个特征取值为j的基尼指数

struct TREE{
    int endd=0;
    int WN=0,PN=0;//该节点包含的数量,点的特征数
    int minw,minv;//该节点划分的特征及特征值
    double ming=INF;
    vector<int> Y;//该节点包含点的Y
    vector<vector<int> > X;//该节点包含点的特征值
    set<int> dx[INF],dy;
    TREE* letr=NULL,*ritr=NULL;
};


int CART(TREE *now){
    memset(gini,0,sizeof(gini));
    for(int i=0;i<now->WN;i++){//遍历X的所有特征
        for(int j=0;j<now->dx[i].size();j++){//遍历Xi特征可以取值的范围
            int Aij=0,AijY=0,Bij=0,BijY=0;//符合Xi特征的人数,符合Xi且Y是1的人数,B不符合Xi的人数
            for(int k=0;k<now->PN;k++){//遍历该结点所有点
                if(now->X[k][i]==j){
                    Aij++;
                    if(now->Y[k]==1){//Y取值为1
                        AijY++;
                    }
                }
                if(now->X[k][i]!=j){
                    Bij++;
                    if(now->Y[k]==1){//Y取值为1
                        BijY++;
                    }
                }
            }
            gini[i][j]+=Aij*1.0/now->PN*(2.0*AijY/Aij*(1.0-AijY*1.0/Aij));
            gini[i][j]+=Bij*1.0/now->PN*(2.0*BijY/Bij*(1.0-BijY*1.0/Bij));
        }
        gini[i][now->dx[i].size()]=INF;
        for(int j=0;j<now->dx[i].size();j++){
            if(gini[i][now->dx[i].size()]>gini[i][j]){
                gini[i][now->dx[i].size()]=gini[i][j];
            }
        }
        printf("第%d个参数最小:%lf\n",i,gini[i][now->dx[i].size()]);
    }
    for(int i=0;i<now->WN;i++){//遍历X的所有特征
        for(int j=0;j<now->dx[i].size();j++){//遍历Xi特征可以取值的范围
            if(now->ming>gini[i][now->dx[i].size()]){
                now->ming=gini[i][now->dx[i].size()];
                now->minw=i;
                now->minv=j;
            }
        }
    }
    for(int i=0;i<now->WN;i++){
        printf("setx%d=%d sety=%d*\n",i,now->dx[i].size(),now->dy.size());
        for(int j=0;j<=now->dx[i].size();j++){
            printf("%0.6lf ",gini[i][j]);
        }
        printf("\n");
    }
    for(int i=0;i<now->PN;i++){
        for(int j=0;j<now->WN;j++){
            printf("%d ",now->X[i][j]);
        }
        printf("%d \n",now->Y[i]);
    }
    printf("选择:%d %d %lf\n",now->minw,now->minv,now->ming);
    return 0;
}

TREE *BuildCART(TREE *now){
    CART(now);
    //printf("%0.6lf,,%0.6lf,,,,%d,,%d,,,,%d--\n",now->ming,Tgini,now->PN,TPN,now->WN);
    if(now->dy.size()==1||now->WN==1){
        now->endd=1;
        printf("%d*****\n\n",now->endd);
        return now;
    }
    printf("\n");
    now->letr=new TREE;
    now->ritr=new TREE;
    vector<int> tv;
    for(int i=0;i<now->PN;i++){
        tv.clear();
        int ad=0;
        if(now->X[i][now->minw]==now->minv){
            for(int j=0;j<now->WN;j++){
                if(j==now->minw){//用来区分的特征舍弃
                    ad--;
                    continue;
                }
                //printf("%d...",now->X[i][j]);
                tv.push_back(now->X[i][j]);
                //printf("%d %d\n",j,now->X[i][j]);
                now->letr->dx[j+ad].insert(now->X[i][j]);
            }
            now->letr->X.push_back(tv);
            now->letr->Y.push_back(now->Y[i]);
            now->letr->dy.insert(now->Y[i]);
            now->letr->PN++;
        }else{
            for(int j=0;j<now->WN;j++){
                if(j==now->minw){//用来区分的特征舍弃
                    ad--;
                    continue;
                }
                tv.push_back(now->X[i][j]);
                now->ritr->dx[j+ad].insert(now->X[i][j]);
            }
            now->ritr->X.push_back(tv);
            now->ritr->Y.push_back(now->Y[i]);
            now->ritr->dy.insert(now->Y[i]);
            now->ritr->PN++;

        }
    }
    now->ritr->WN=now->WN-1;
    now->letr->WN=now->WN-1;
    BuildCART(now->letr);
    BuildCART(now->ritr);
    return now;
}

int judge(TREE *now){
    int res=0;
    if(now->endd==1){
        return now->Y[0];
    }
    if(aim[now->minw]==now->minv){
        res=judge(now->letr);
    }else{
        res=judge(now->ritr);
    }
    return res;
}

int main(){
    freopen("in.txt","r",stdin);
    TREE *rt=new TREE;
    scanf("%d %d",&rt->WN,&rt->PN);
    vector<int> tv;
    int temp;
    for(int i=0;i<rt->PN;i++){
        tv.clear();
        for(int j=0;j<rt->WN;j++){
            scanf("%d",&temp);
            tv.push_back(temp);
            rt->dx[j].insert(temp);
        }
        rt->X.push_back(tv);
        scanf("%d",&temp);
        rt->Y.push_back(temp);
        rt->dy.insert(temp);
    }
    rt=BuildCART(rt);


    int tn;
    scanf("%d",&tn);
    while(tn--){
        for(int i=0;i<rt->WN;i++){
            scanf("%d",&aim[i]);
        }
        printf("预测:%d\n",judge(rt));
    }
    return 0;
}



4 15
0 0 0 0 0
0 0 0 1 0
0 1 0 1 1
0 1 1 0 1
0 0 0 0 0
1 0 0 0 0
1 0 0 1 0
1 1 1 1 1
1 0 1 2 1
1 0 1 2 1
2 0 1 2 1
2 0 1 1 1
2 1 0 1 1
2 1 0 2 1
2 0 0 0 0
9
0 0 0 0
1 1 1 1
1 0 1 2
1 0 1 2
2 0 1 2
2 0 1 1
2 1 0 1
2 1 0 2
2 0 0 0



他喵的,花了我一晚上,结构体指针中的vector在绑定内存前不能进行赋值,我日,调了半天。
这个是没有优化的决策树,佛了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值