#include <stdio.h>
#include <string.h>
#include <vector>
#include <set>
using namespace std;
const int INF=20,TPN=2;
const double Tgini=0.001;
vector<int> aim(INF);
double gini[INF][INF];//第Xi个特征取值为j的基尼指数
struct TREE{
int endd=0;
int WN=0,PN=0;//该节点包含的数量,点的特征数
int minw,minv;//该节点划分的特征及特征值
double ming=INF;
vector<int> Y;//该节点包含点的Y
vector<vector<int> > X;//该节点包含点的特征值
set<int> dx[INF],dy;
TREE* letr=NULL,*ritr=NULL;
};
int CART(TREE *now){
memset(gini,0,sizeof(gini));
for(int i=0;i<now->WN;i++){//遍历X的所有特征
for(int j=0;j<now->dx[i].size();j++){//遍历Xi特征可以取值的范围
int Aij=0,AijY=0,Bij=0,BijY=0;//符合Xi特征的人数,符合Xi且Y是1的人数,B不符合Xi的人数
for(int k=0;k<now->PN;k++){//遍历该结点所有点
if(now->X[k][i]==j){
Aij++;
if(now->Y[k]==1){//Y取值为1
AijY++;
}
}
if(now->X[k][i]!=j){
Bij++;
if(now->Y[k]==1){//Y取值为1
BijY++;
}
}
}
gini[i][j]+=Aij*1.0/now->PN*(2.0*AijY/Aij*(1.0-AijY*1.0/Aij));
gini[i][j]+=Bij*1.0/now->PN*(2.0*BijY/Bij*(1.0-BijY*1.0/Bij));
}
gini[i][now->dx[i].size()]=INF;
for(int j=0;j<now->dx[i].size();j++){
if(gini[i][now->dx[i].size()]>gini[i][j]){
gini[i][now->dx[i].size()]=gini[i][j];
}
}
printf("第%d个参数最小:%lf\n",i,gini[i][now->dx[i].size()]);
}
for(int i=0;i<now->WN;i++){//遍历X的所有特征
for(int j=0;j<now->dx[i].size();j++){//遍历Xi特征可以取值的范围
if(now->ming>gini[i][now->dx[i].size()]){
now->ming=gini[i][now->dx[i].size()];
now->minw=i;
now->minv=j;
}
}
}
for(int i=0;i<now->WN;i++){
printf("setx%d=%d sety=%d*\n",i,now->dx[i].size(),now->dy.size());
for(int j=0;j<=now->dx[i].size();j++){
printf("%0.6lf ",gini[i][j]);
}
printf("\n");
}
for(int i=0;i<now->PN;i++){
for(int j=0;j<now->WN;j++){
printf("%d ",now->X[i][j]);
}
printf("%d \n",now->Y[i]);
}
printf("选择:%d %d %lf\n",now->minw,now->minv,now->ming);
return 0;
}
TREE *BuildCART(TREE *now){
CART(now);
//printf("%0.6lf,,%0.6lf,,,,%d,,%d,,,,%d--\n",now->ming,Tgini,now->PN,TPN,now->WN);
if(now->dy.size()==1||now->WN==1){
now->endd=1;
printf("%d*****\n\n",now->endd);
return now;
}
printf("\n");
now->letr=new TREE;
now->ritr=new TREE;
vector<int> tv;
for(int i=0;i<now->PN;i++){
tv.clear();
int ad=0;
if(now->X[i][now->minw]==now->minv){
for(int j=0;j<now->WN;j++){
if(j==now->minw){//用来区分的特征舍弃
ad--;
continue;
}
//printf("%d...",now->X[i][j]);
tv.push_back(now->X[i][j]);
//printf("%d %d\n",j,now->X[i][j]);
now->letr->dx[j+ad].insert(now->X[i][j]);
}
now->letr->X.push_back(tv);
now->letr->Y.push_back(now->Y[i]);
now->letr->dy.insert(now->Y[i]);
now->letr->PN++;
}else{
for(int j=0;j<now->WN;j++){
if(j==now->minw){//用来区分的特征舍弃
ad--;
continue;
}
tv.push_back(now->X[i][j]);
now->ritr->dx[j+ad].insert(now->X[i][j]);
}
now->ritr->X.push_back(tv);
now->ritr->Y.push_back(now->Y[i]);
now->ritr->dy.insert(now->Y[i]);
now->ritr->PN++;
}
}
now->ritr->WN=now->WN-1;
now->letr->WN=now->WN-1;
BuildCART(now->letr);
BuildCART(now->ritr);
return now;
}
int judge(TREE *now){
int res=0;
if(now->endd==1){
return now->Y[0];
}
if(aim[now->minw]==now->minv){
res=judge(now->letr);
}else{
res=judge(now->ritr);
}
return res;
}
int main(){
freopen("in.txt","r",stdin);
TREE *rt=new TREE;
scanf("%d %d",&rt->WN,&rt->PN);
vector<int> tv;
int temp;
for(int i=0;i<rt->PN;i++){
tv.clear();
for(int j=0;j<rt->WN;j++){
scanf("%d",&temp);
tv.push_back(temp);
rt->dx[j].insert(temp);
}
rt->X.push_back(tv);
scanf("%d",&temp);
rt->Y.push_back(temp);
rt->dy.insert(temp);
}
rt=BuildCART(rt);
int tn;
scanf("%d",&tn);
while(tn--){
for(int i=0;i<rt->WN;i++){
scanf("%d",&aim[i]);
}
printf("预测:%d\n",judge(rt));
}
return 0;
}
4 15
0 0 0 0 0
0 0 0 1 0
0 1 0 1 1
0 1 1 0 1
0 0 0 0 0
1 0 0 0 0
1 0 0 1 0
1 1 1 1 1
1 0 1 2 1
1 0 1 2 1
2 0 1 2 1
2 0 1 1 1
2 1 0 1 1
2 1 0 2 1
2 0 0 0 0
9
0 0 0 0
1 1 1 1
1 0 1 2
1 0 1 2
2 0 1 2
2 0 1 1
2 1 0 1
2 1 0 2
2 0 0 0
他喵的,花了我一晚上,结构体指针中的vector在绑定内存前不能进行赋值,我日,调了半天。
这个是没有优化的决策树,佛了。