实验要求
有27个训练样本数据,每个样本的属性表为{Age sex region income married children car mortgage},并做了如下的预处理:
Age:30岁以下标记为“1”;30岁以上50岁以下标记为“2”;50岁以上标记为“3”。
Sex:FEMAL—-“1”;MALE—-“2”
Region:INNER CITY—-“1”; TOWN—-“2”; RURAL—-“3”; SUBURBAN—-“4”
Income:5000~2万—-“1”;2万~4万—-“2”;4万以上—-“3”
Married:已婚—1; 未婚—2
Children:有—-1; 无—2
Car:有—-1; 无—-2
Mortgage:有—1; 无—2;
分类Pep:以上八个条件,若为“是”标记为“1”,若为“否”标记为“2”。
用ID3算法设计实现决策树,实现分类。
代码实现
#include<iostream>
#include<vector>
#include<set>
#include<cmath>
#include<fstream>
using namespace std;
int object[27][9];
int test[37][8];
int obj_size;
int test_size;
class Node{ //节点类
public:
int flag;//叶节点标志
int final;//结果
int judge_attr_value;//属性判断值
int judge_attr_no; //属性编号
set<int> obj;//对象集
int attr[8];//属性可用集
vector<Node *> child;//子节点
}root;
void input(){
ifstream in("sample.txt");
in>>obj_size;
for(int i=0;i<obj_size;i++)
for(int j=0;j<9;j++)
in>>object[i][j];
}
void inputTest(){
ifstream in("test.txt");
in>>test_size;
for(int i=0;i<test_size;i++)
for(int j=0;j<8;j++)
in>>test[i][j];
}
void prepare(){//初始化
root.flag = 0;
root.final = 0;
for(int i=0;i<obj_size;i++) root.obj.insert(i);
for(int i=0;i<8;i++) root.attr[i]=1;
}
int Choose_Attribute(Node *node){ //选择属性
double min=99999;
int min_attr;//最小属性
double size = (double) node->obj.size();//对象个数
for(int i=0;i<8;i++){ //对于每一个属性
if(node->attr[i]){ //属性可选
double attr_classify[6][3]={0};//属性分类
for(set<int>::iterator it = node->obj.begin(); it != node->obj.end(); ++it){//本节点拥有的对象集合的属性分类
attr_classify[object[(*it)][i]][object[(*it)][8]]++;
attr_classify[object[(*it)][i]][0]++;
}
double shang=0;
for(int j=0;j<6;j++){//对于属性的每个值
double s=0;
int num=0;
double T,F;
if(attr_classify[j][0]) { //属性存在,计算熵值
if(attr_classify[j][1]==0) T=0;//第一个结果不存在
if(attr_classify[j][2]==0) F=0;//第二个结果不存在
if(attr_classify[j][1]&&attr_classify[j][2]){//结果都存在
T = attr_classify[j][1]/attr_classify[j][0] * -( log(attr_classify[j][1]/attr_classify[j][0])/log(2));
F = attr_classify[j][2]/attr_classify[j][0] * -( log(attr_classify[j][2]/attr_classify[j][0])/log(2));
}
s = attr_classify[j][0]/size *( T + F );//熵
}
shang += s;
}
if(min>shang){
min=shang;
min_attr=i;//保存属性
}
}
}
return min_attr;//返回最优属性
}
int node_for_each(Node *node){ //节点构造函数
int final_set[3]={0};//结果集
for(set<int>::iterator it = node->obj.begin(); it != node->obj.end(); ++it) final_set[object[*it][8]]++;
if(final_set[1]==0||final_set[2]==0){//属性结果一致
node->flag = 1;
if(final_set[1]) node->final = 1;
if(final_set[2]) node->final = 2;
return 0;
}
if(final_set[1]&&final_set[2]){//属性结果不一致,继续分类
int attribute_no;
int attr[8];
for(int i=0;i<8;i++) attr[i]=node->attr[i];//属性标志数组
attribute_no = Choose_Attribute(node); //选择属性
attr[attribute_no] = 0;//属性已选
set<int> obj_for_attr[6];
for(set<int>::iterator it = node->obj.begin(); it != node->obj.end(); ++it){
for(int i=0;i<6;i++){
obj_for_attr[object[(*it)][attribute_no]].insert((*it));//对于某属性的不同值的对象分类
}
}
for(int i=0;i<6;i++){//对于每一种分类,创建子节点
if(obj_for_attr[i].size()!=0){ //有效分类
Node *childnode =new Node();
childnode->judge_attr_value = i; //节点接受的属性值
//赋值
childnode->obj = obj_for_attr[i];
childnode->judge_attr_no = attribute_no;
for(int j=0;j<8;j++) childnode->attr[j] = attr[j];
//添加到父节点 递归
node->child.push_back(childnode);
node_for_each(childnode);
}
}
}
}
void printTree(Node *node){//打印树
cout<<"节点判断属性:"<<node->judge_attr_no<<" 属性值:"<<node->judge_attr_value<<" 属性个数:"<<node->child.size()<<endl;
for(int i=0;i<node->child.size();i++){
Node *p = node->child[i];
cout<<"-->"<<"子节点判断属性:"<<p->judge_attr_no<<" 属性值:"<<p->judge_attr_value<<" 结果:"<<p->child.size()<<endl;
}
cout<<"_____________________________________________________________________________\n";
for(int i=0;i<node->child.size();i++){
Node *p = node->child[i];
printTree(p);
}
}
int judge(Node *node,int size){//判断
Node *p = node;
int result;
int fz;
for(int i=0;i<size;i++){//test集
p = node;
while(p->flag!=1){
fz = p->child.size();
for(int j=0;j<fz;j++){//对于每个分支
Node *childp = p->child[j];//选择子节点
if(childp->judge_attr_value == test[i][childp->judge_attr_no]){
p=childp;
break;
}
}
}
result = p->final;
cout<<result<<" ";
}
}
int main(){
input();
prepare();
node_for_each(&root);
//printTree(&root);
inputTest();
judge(&root,test_size);
}
sample.txt:
27
1 2 1 1 2 1 1 2 2
1 2 1 1 2 2 2 2 1
2 1 4 1 2 1 2 2 1
2 1 1 1 1 2 2 2 2
1 2 1 1 1 2 2 2 2
1 2 1 1 2 1 2 1 1
2 1 2 1 1 2 1 1 2
2 1 1 1 2 1 1 2 1
2 1 3 1 2 2 1 2 1
2 1 2 2 2 1 2 2 2
2 2 1 2 2 2 2 1 1
2 1 2 2 1 1 2 1 1
2 2 1 2 1 2 2 1 2
1 1 1 2 1 2 2 2 1
3 2 1 2 1 1 1 2 2
1 1 1 2 1 1 1 2 1
1 1 3 2 2 2 1 2 1
3 1 2 2 1 2 2 2 1
3 2 3 3 1 1 1 2 1
3 2 2 3 1 2 1 1 2
3 1 3 3 1 1 2 2 1
3 2 1 3 1 2 1 2 2
3 2 1 3 1 1 1 1 1
3 1 1 3 1 2 1 1 2
3 1 3 3 1 2 2 2 2
3 2 4 3 1 2 2 1 1
3 1 3 3 2 2 1 1 2
test.txt:
37
1 2 1 1 2 1 1 2
1 2 1 1 2 2 2 2
2 1 4 1 2 1 2 2
2 1 1 1 1 2 2 2
1 2 1 1 1 2 2 2
1 2 1 1 2 1 2 1
2 1 2 1 1 2 1 1
2 1 1 1 2 1 1 2
2 1 3 1 2 2 1 2
2 1 2 2 2 1 2 2
2 2 1 2 2 2 2 1
2 1 2 2 1 1 2 1
2 2 1 2 1 2 2 1
1 1 1 2 1 2 2 2
3 2 1 2 1 1 1 2
1 1 1 2 1 1 1 2
1 1 3 2 2 2 1 2
3 1 2 2 1 2 2 2
3 2 3 3 1 1 1 2
3 2 2 3 1 2 1 1
3 1 3 3 1 1 2 2
3 2 1 3 1 2 1 2
3 2 1 3 1 1 1 1
3 1 1 3 1 2 1 1
3 1 3 3 1 2 2 2
3 2 4 3 1 2 2 1
3 1 3 3 2 2 1 1
1 2 1 3 1 2 2 2
1 2 1 3 2 1 1 1
1 1 2 3 1 1 2 1
2 1 4 1 1 2 1 1
1 1 3 2 1 2 2 2
1 2 1 1 2 1 1 1
2 2 2 2 1 2 1 1
1 1 1 1 1 1 1 1
2 1 3 3 2 2 1 2
1 1 2 2 2 1 2 2
此为记录实验用