最近在阅读机器学习中的决策树,突发奇想,写了一个 ID3 算法的实现。代码尚未完成,先保存下来,等有时间再接着写。大体框架如下。
目前还没有编写可运行的样例,也尚未用实际数据进行实例化。
实现说明如下:
#include <iostream>
#include <string>
#include <decisionTree.h>
using namespace std;
int main()
{
AA attr;
DD sample;
decisionTree *m=new decisionTree();
NODE *root;
m->TreeGenerate(sample,attr,root);
}
#ifndef DECISIONTREE_H
#define DECISIONTREE_H
#define MAX 10000
#include <string.h>
#include <iostream>
#include <vector>
using namespace std;
/// One attribute of a sample: which attribute it is and the value it takes.
struct A {
    int id = 0;         // attribute ID; used as the index into D::x by find_reA / findD_inA
    std::string name;   // attribute name
    int value = 0;      // attribute value
};
struct AA
{
int a_num; //属性集的大小
vector<A> aa; //属性集
vector<bool> canUse;
};
struct D //(xi,yi)
{
int x_num; // 属性X的长度
int y_num; // y的类别个数
vector<A> x;// xi向量值
int y; // yi值
};
struct DD //样本集
{
int d_num;
vector<D> dd;
};
/// Summary of one attribute over a sample set: its identity plus the
/// distinct values it takes there (built by find_reA).
struct re_A {
    int id = 0;
    std::string name;
    int value_num = 0;          // number of distinct values (values.size())
    std::vector<int> values;    // the distinct attribute values
};
/// A node of the decision tree.
struct NODE
{
    int type = 0;               // 0 = root, 1 = internal node, 2 = leaf
    int CLASS = 0;              // predicted class (same label space as D::y); set on leaves
    // Child nodes, one per value in attribute.values. NODE has no destructor,
    // so these pointers are not freed by the node itself.
    std::vector<NODE*> child;
    int child_num = 0;          // number of children; starts at 0 because it is incremented
    re_A attribute;             // the attribute this node splits on
};
/// ID3 decision-tree learner. The class holds no state of its own; its
/// operations take the sample set (DD) and attribute set (AA) by value.
class decisionTree
{
public:
    decisionTree();
    virtual ~decisionTree();

    /// Grows the subtree rooted at node from samples d and attribute set a.
    void TreeGenerate(DD d, AA a, NODE *node);

    // --- stopping-condition tests ---
    bool allSameType(DD d);          // do all samples belong to one class?
    bool allSameinA(DD d);           // are all samples identical on the attributes (inseparable)?
    int mostClass(DD d);             // class that most samples belong to

    // --- split selection ---
    re_A findBestA(DD d, AA a);      // attribute chosen for the split
    double Gain(DD d, re_A a);       // information gain of splitting d on a
    double Ent(DD d);                // entropy of d's class labels

    // --- sample/attribute helpers ---
    re_A find_reA(DD d, A a);        // distinct values attribute a takes in d
    DD findD_inA(DD d, A a);         // samples of d matching a's value
};
#endif // DECISIONTREE_H
#include "decisionTree.h"

#include <math.h>
#include <stdio.h>

#include <cmath>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

using namespace std;
// Nothing to set up: decisionTree declares no data members.
decisionTree::decisionTree() = default;
// Nothing to release; the destructor is declared virtual in the class definition.
decisionTree::~decisionTree() = default;
/// Information entropy of the class labels in d, in bits:
///   Ent(D) = -sum_k p_k * log2(p_k),   p_k = share of samples with label k.
/// Only labels that actually occur contribute, so an absent class no longer
/// produces 0 * log(0) = NaN. Works for any integer labels (no MAX-sized table).
/// An empty sample set has entropy 0.
double decisionTree::Ent(DD d)
{
    const std::size_t total = d.dd.size();
    if (total == 0)
        return 0.0;

    std::map<int, int> counts;  // class label -> number of samples
    for (const D& sample : d.dd)  // 0-based: the original skipped dd[0] and read past the end
        ++counts[sample.y];

    double entropy = 0.0;
    for (const auto& entry : counts)
    {
        const double p = static_cast<double>(entry.second) / static_cast<double>(total);
        entropy -= p * std::log2(p);
    }
    return entropy;
}
/// Information gain of splitting d on attribute a:
///   Gain(D, a) = Ent(D) - sum_v |D^v| / |D| * Ent(D^v)
/// where D^v holds the samples whose value on attribute a.id equals v,
/// for every v in a.values. Returns 0 for an empty sample set.
double decisionTree::Gain(DD d, re_A a)
{
    const std::size_t total = d.dd.size();
    if (total == 0)
        return 0.0;

    double conditional = 0.0;  // weighted entropy remaining after the split
    for (const int v : a.values)  // 0-based: the original read values[1..value_num]
    {
        A probe;
        probe.id = a.id;
        probe.name = a.name;
        probe.value = v;
        const DD subset = findD_inA(d, probe);
        // Floating-point weight: the original int/int division was always 0.
        const double weight =
            static_cast<double>(subset.dd.size()) / static_cast<double>(total);
        conditional += weight * Ent(subset);
    }
    return Ent(d) - conditional;
}
/// Reports whether the first d.d_num samples all carry the same class label.
/// Vacuously true when d.d_num is not positive.
bool decisionTree::allSameType(DD d)
{
    if (d.d_num <= 0)
        return true;

    const int firstLabel = d.dd[0].y;
    for (int k = 1; k < d.d_num; ++k)
    {
        if (d.dd[k].y != firstLabel)
            return false;
    }
    return true;
}
/// Returns true when every sample in d has the same value on each of the
/// first sample's x_num attributes, i.e. no attribute can separate them.
/// An empty or single-sample set counts as identical.
/// NOTE(review): this compares all attributes, not only those still usable in
/// the current attribute set (the function has no AA argument) — confirm intent.
bool decisionTree::allSameinA(DD d)
{
    if (d.dd.empty())
        return true;

    const D& first = d.dd[0];
    for (std::size_t j = 1; j < d.dd.size(); ++j)    // every other sample
        for (int i = 0; i < first.x_num; ++i)         // every attribute
            if (d.dd[j].x[i].value != first.x[i].value)
                return false;

    // The original fell off the end of this non-void function here (undefined behaviour).
    return true;
}
/// Returns the class label shared by the most samples in d.
/// Ties go to the smallest label; an empty sample set yields -1.
/// The original only compared neighbouring counts (result[i] > result[i-1])
/// and could return an uninitialized value; this tracks the true maximum.
int decisionTree::mostClass(DD d)
{
    std::map<int, int> counts;  // class label -> number of samples (ascending by label)
    for (const D& sample : d.dd)
        ++counts[sample.y];

    int best = -1;
    int bestCount = 0;
    for (const auto& entry : counts)
    {
        if (entry.second > bestCount)  // strict '>' keeps the smallest label on ties
        {
            best = entry.first;
            bestCount = entry.second;
        }
    }
    return best;
}
/// Collects the distinct values attribute a (looked up by a.id in each
/// sample's x) takes across d. The result carries a's id and name; values are
/// listed in descending order, as in the original.
re_A decisionTree::find_reA(DD d, A a)
{
    std::set<int> distinct;
    for (const D& sample : d.dd)  // 0-based: the original skipped dd[0] and read past the end
        distinct.insert(sample.x[a.id].value);

    re_A c;
    c.id = a.id;
    c.name = a.name;
    c.values.assign(distinct.rbegin(), distinct.rend());
    c.value_num = static_cast<int>(c.values.size());
    return c;  // the original had no return statement (undefined behaviour)
}
/// Among the attributes in a.aa that are still marked usable, picks the one
/// with the highest information gain on d and returns its value summary
/// (see find_reA). Ties go to the attribute that appears first in a.aa.
/// If no attribute is usable, returns a re_A with id == -1 and no values
/// (the original indexed a.aa[-1] in that case).
/// NOTE(review): canUse is read by attribute id, matching how TreeGenerate
/// clears it (canUse[best_a.id]); the original read it by position in a.aa.
re_A decisionTree::findBestA(DD d, AA a)
{
    re_A best;
    best.id = -1;
    best.value_num = 0;

    bool found = false;
    double bestGain = 0.0;
    for (const A& attr : a.aa)  // 0-based: the original started at index 1
    {
        const bool usable = attr.id >= 0
            && static_cast<std::size_t>(attr.id) < a.canUse.size()
            && a.canUse[attr.id];
        if (!usable)
            continue;

        const re_A candidate = find_reA(d, attr);
        const double gain = Gain(d, candidate);
        if (!found || gain > bestGain)
        {
            found = true;
            bestGain = gain;
            best = candidate;
        }
    }
    return best;
}
/// Returns the samples of d whose value on attribute a.id equals a.value,
/// in their original order; d_num of the result is set to how many were kept.
DD decisionTree::findD_inA(DD d, A a)
{
    DD subD;
    for (const D& sample : d.dd)  // 0-based: the original skipped dd[0] and read past the end
    {
        if (sample.x[a.id].value == a.value)
            subD.dd.push_back(sample);
    }
    subD.d_num = static_cast<int>(subD.dd.size());
    return subD;
}
void decisionTree::TreeGenerate(DD d, AA a, NODE *node)
{
if (allSameType(d))
{
node->type=3; //叶节点
node->CLASS=d.dd[0].y;
//node->sample=d;
return ;
}
else if(a.a_num==0||allSameinA(d))
{
node->type=3;
node->CLASS=mostClass(d);
//node->sample=d;
return ;
}
else
{
re_A best_a;
best_a=findBestA(d,a);
node->attribute=best_a;
for(int i=1;i<=best_a.value_num;i++)
{
NODE* c_node;
node->child_num ++;
node->child.push_back(c_node);
A c_a;
c_a.id=best_a.id;
c_a.name=best_a.name;
c_a.value=best_a.values[i];
DD c_d;
c_d=findD_inA(d,c_a);
if(c_d.d_num=0)
{
node->child[node->child_num]->type=3;
node->child[node->child_num]->CLASS=mostClass(d);
//node->sample=d;
return ;
}
else{
a.canUse[best_a.id]=false; //去除a中使用的属性
TreeGenerate(c_d,a,node->child[node->child_num]);
}
}
}
}