很久没写过用到这么多 STL 的程序了,这次特意用了 set、map、vector 来练练手。
也记录一下吧,虽然写得比较渣。
共三个文件:测试数据 data.txt、头文件 id3.h(紧接在测试数据之后列出)和程序源文件 id3.cpp。
测试数据:data.txt
- D1 Sunny Hot High Weak No
- D2 Sunny Hot High Strong No
- D3 Overcast Hot High Weak Yes
- D4 Rain Mild High Weak Yes
- D5 Rain Cool Normal Weak Yes
- D6 Rain Cool Normal Strong No
- D7 Overcast Cool Normal Strong Yes
- D8 Sunny Mild High Weak No
- D9 Sunny Cool Normal Weak Yes
- D10 Rain Mild Normal Weak Yes
- D11 Sunny Mild Normal Strong Yes
- D12 Overcast Mild High Strong Yes
- D13 Overcast Hot Normal Weak Yes
- D14 Rain Mild High Strong No
- #ifndef ID3_H
- #define ID3_H
- #include<fstream>
- #include<iostream>
- #include<vector>
- #include<map>
- #include<set>
- #include<cmath>
- using namespace std;
/* Dimensions of the training table read from data.txt. */
const int DataRow    = 14;  // number of training examples (one per line of data.txt)
const int DataColumn = 6;   // fields per example: Day, four attributes, PlayTennis label
/*
 * One node of the decision tree.
 *
 * A node with an empty childNode list is a leaf. Every member is given a
 * defined starting value so that a freshly allocated node can be inspected
 * safely even when the builder returns before filling it in.
 */
struct Node
{
    double value;                  // probability of "Yes" at this node
    int attrid;                    // index of the attribute this node splits on, -1 if none
    Node * parentNode;             // node this one hangs from, NULL for the root
    std::vector<Node*> childNode;  // one subtree per value of the split attribute

    Node() : value(0), attrid(-1), parentNode(NULL) {}
};
- #endif
程序源文件:id3.cpp
- #include "id3.h"
- string DataTable[DataRow][DataColumn];
- map<string,int> str2int;
- set<int> S;
- set<int> Attributes;
- string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};
- string attrValue[DataColumn][DataRow]=
- {
- {},//D1,D2这个属性不需要
- {"Sunny","Overcast","Rain"},
- {"Hot","Mild","Cool"},
- {"High","Normal"},
- {"Weak","Strong"},
- {"No","Yes"}
- };
- int attrCount[DataColumn]={14,3,3,2,2,2};
/*
 * Base-2 logarithm of n, obtained from the natural logarithm by the
 * change-of-base rule log2(n) = ln(n) / ln(2).
 */
double lg2(double n)
{
    const double naturalLogOfN=std::log(n);
    const double naturalLogOfTwo=std::log(2.0);
    return naturalLogOfN/naturalLogOfTwo;
}
- void Init()
- {
- ifstream fin("data.txt");
- for(int i=0;i<14;i++)
- {
- for(int j=0;j<6;j++)
- {
- fin>>DataTable[i][j];
- }
- }
- fin.close();
- for(int i=1;i<=5;i++)
- {
- str2int[attrName[i]]=i;
- for(int j=0;j<attrCount[i];j++)
- {
- str2int[attrValue[i][j]]=j;
- }
- }
- for(int i=0;i<DataRow;i++)
- S.insert(i);
- for(int i=1;i<=4;i++)
- Attributes.insert(i);
- }
- double Entropy(const set<int> &s)
- {
- double yes=0,no=0,sum=s.size(),ans=0;
- for(set<int>::iterator it=s.begin();it!=s.end();it++)
- {
- string s=DataTable[*it][str2int["PlayTennis"]];
- if(s=="Yes")
- yes++;
- else
- no++;
- }
- if(no==0||yes==0)
- return ans=0;
- ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);
- return ans;
- }
- double Gain(const set<int> & example,int attrid)
- {
- int attrcount=attrCount[attrid];
- double ans=Entropy(example);
- double sum=example.size();
- set<int> * pset=new set<int>[attrcount];
- for(set<int>::iterator it=example.begin();it!=example.end();it++)
- {
- pset[str2int[DataTable[*it][attrid]]].insert(*it);
- }
- for(int i=0;i<attrcount;i++)
- {
- ans-=pset[i].size()/sum*Entropy(pset[i]);
- }
- return ans;
- }
- int FindBestAttribute(const set<int> & example,const set<int> & attr)
- {
- double mx=0;
- int k=-1;
- for(set<int>::iterator i=attr.begin();i!=attr.end();i++)
- {
- double ret=Gain(example,*i);
- if(ret>mx)
- {
- mx=ret;
- k=*i;
- }
- }
- if(k==-1)
- cout<<"FindBestAttribute error!"<<endl;
- return k;
- }
- Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent)
- {
- Node *now=new Node;//创建树节点。
- now->parentNode=parent;
- if(attributes.empty())//如果此时属性列表已用完,即为空,则返回。
- return now;
- /*
- * 统计一下example,如果都为正或者都为负则表示已经抵达决策树的叶子节点
- * 叶子节点的特征是有childNode为空。
- */
- int yes=0,no=0,sum=example.size();
- for(set<int>::iterator it=example.begin();it!=example.end();it++)
- {
- string s=DataTable[*it][str2int["PlayTennis"]];
- if(s=="Yes")
- yes++;
- else
- no++;
- }
- if(yes==sum||yes==0)
- {
- now->value=yes/sum;
- return now;
- }
- /*找到最高信息增益的属性并将该属性从attributes集合中删除*/
- int bestattrid=FindBestAttribute(example,attributes);
- now->attrid=bestattrid;
- attributes.erase(attributes.find(bestattrid));
- /*将exmple根据最佳属性的不同属性值分成几个分支,每个分支有即一个子树*/
- vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]);
- for(set<int>::iterator i=example.begin();i!=example.end();i++)
- {
- int id=str2int[DataTable[*i][bestattrid]];
- child[id].insert(*i);
- }
- for(int i=0;i<child.size();i++)
- {
- Node * ret=Id3_solution(child[i],attributes,now);
- now->childNode.push_back(ret);
- }
- return now;
- }
- int main()
- {
- Init();
- Node * Root=Id3_solution(S,Attributes,NULL);
- return 0;
- }