决策树学习 之 ID3 C++STL代码实现

很久没写含这么多stl的程序了,很故意的用set,map,vector,熟手一下。

也记录一下吧,虽然写得比较渣。


三个文件:

测试数据:data.txt

[plain] view plain copy
  1. D1    Sunny        Hot    High        Weak    No  
  2. D2    Sunny        Hot    High        Strong    No  
  3. D3    Overcast    Hot    High        Weak    Yes  
  4. D4    Rain        Mild    High        Weak    Yes  
  5. D5    Rain        Cool    Normal        Weak    Yes  
  6. D6    Rain        Cool    Normal        Strong    No  
  7. D7    Overcast    Cool    Normal        Strong    Yes  
  8. D8    Sunny        Mild    High        Weak    No  
  9. D9    Sunny        Cool    Normal        Weak    Yes  
  10. D10    Rain        Mild    Normal        Weak    Yes  
  11. D11    Sunny        Mild    Normal        Strong    Yes  
  12. D12    Overcast    Mild    High        Strong    Yes  
  13. D13    Overcast    Hot    Normal        Weak    Yes  
  14. D14    Rain        Mild    High        Strong    No  


程序头文件:id3.h
  1. #ifndef ID3_H  
  2. #define ID3_H  
  3. #include<fstream>  
  4. #include<iostream>  
  5. #include<vector>  
  6. #include<map>  
  7. #include<set>  
  8. #include<cmath>  
  9. using namespace std;  
  10. const int DataRow=14;  
  11. const int DataColumn=6;  
  12. struct Node  
  13. {  
  14.     double value;//代表此时yes的概率。  
  15.     int attrid;  
  16.     Node * parentNode;  
  17.     vector<Node*> childNode;  
  18. };  
  19. #endif  

程序源文件id3.cpp

  1. #include "id3.h"  
  2.   
  3. string DataTable[DataRow][DataColumn];  
  4. map<string,int> str2int;  
  5. set<int> S;  
  6. set<int> Attributes;  
  7. string attrName[DataColumn]={"Day","Outlook","Temperature","Humidity","Wind","PlayTennis"};  
  8. string attrValue[DataColumn][DataRow]=  
  9. {  
  10.     {},//D1,D2这个属性不需要  
  11.     {"Sunny","Overcast","Rain"},  
  12.     {"Hot","Mild","Cool"},  
  13.     {"High","Normal"},  
  14.     {"Weak","Strong"},  
  15.     {"No","Yes"}  
  16. };  
  17. int attrCount[DataColumn]={14,3,3,2,2,2};  
  18. double lg2(double n)  
  19. {  
  20.     return log(n)/log(2);  
  21. }  
  22. void Init()  
  23. {  
  24.     ifstream fin("data.txt");  
  25.     for(int i=0;i<14;i++)  
  26.     {  
  27.       for(int j=0;j<6;j++)  
  28.       {  
  29.           fin>>DataTable[i][j];  
  30.       }  
  31.     }  
  32.     fin.close();  
  33.     for(int i=1;i<=5;i++)  
  34.     {  
  35.         str2int[attrName[i]]=i;  
  36.         for(int j=0;j<attrCount[i];j++)  
  37.         {  
  38.             str2int[attrValue[i][j]]=j;  
  39.         }  
  40.     }  
  41.     for(int i=0;i<DataRow;i++)  
  42.       S.insert(i);  
  43.     for(int i=1;i<=4;i++)  
  44.       Attributes.insert(i);  
  45. }  
  46.   
  47. double Entropy(const set<int> &s)  
  48. {  
  49.     double yes=0,no=0,sum=s.size(),ans=0;  
  50.     for(set<int>::iterator it=s.begin();it!=s.end();it++)  
  51.     {  
  52.         string s=DataTable[*it][str2int["PlayTennis"]];  
  53.         if(s=="Yes")  
  54.           yes++;  
  55.         else  
  56.           no++;  
  57.     }  
  58.     if(no==0||yes==0)  
  59.       return ans=0;  
  60.     ans=-yes/sum*lg2(yes/sum)-no/sum*lg2(no/sum);  
  61.     return ans;  
  62. }  
  63. double Gain(const set<int> & example,int attrid)  
  64. {  
  65.     int attrcount=attrCount[attrid];  
  66.     double ans=Entropy(example);  
  67.     double sum=example.size();  
  68.     set<int> * pset=new set<int>[attrcount];  
  69.     for(set<int>::iterator it=example.begin();it!=example.end();it++)  
  70.     {  
  71.         pset[str2int[DataTable[*it][attrid]]].insert(*it);  
  72.     }  
  73.     for(int i=0;i<attrcount;i++)  
  74.     {  
  75.         ans-=pset[i].size()/sum*Entropy(pset[i]);  
  76.     }  
  77.     return ans;  
  78. }  
  79. int FindBestAttribute(const set<int> & example,const set<int> & attr)  
  80. {  
  81.     double mx=0;  
  82.     int k=-1;  
  83.     for(set<int>::iterator i=attr.begin();i!=attr.end();i++)  
  84.     {  
  85.         double ret=Gain(example,*i);  
  86.         if(ret>mx)  
  87.         {  
  88.             mx=ret;  
  89.             k=*i;  
  90.         }  
  91.     }  
  92.     if(k==-1)  
  93.       cout<<"FindBestAttribute error!"<<endl;  
  94.     return k;  
  95. }  
  96. Node * Id3_solution(set<int> example,set<int> & attributes,Node * parent)  
  97. {  
  98.     Node *now=new Node;//创建树节点。  
  99.     now->parentNode=parent;  
  100.     if(attributes.empty())//如果此时属性列表已用完,即为空,则返回。  
  101.       return now;  
  102.   
  103.     /* 
  104.      * 统计一下example,如果都为正或者都为负则表示已经抵达决策树的叶子节点 
  105.      * 叶子节点的特征是有childNode为空。 
  106.      */  
  107.     int yes=0,no=0,sum=example.size();  
  108.     for(set<int>::iterator it=example.begin();it!=example.end();it++)  
  109.     {  
  110.         string s=DataTable[*it][str2int["PlayTennis"]];  
  111.         if(s=="Yes")  
  112.           yes++;  
  113.         else  
  114.           no++;  
  115.     }  
  116.     if(yes==sum||yes==0)  
  117.     {  
  118.         now->value=yes/sum;  
  119.         return now;  
  120.     }  
  121.       
  122.   
  123.     /*找到最高信息增益的属性并将该属性从attributes集合中删除*/  
  124.     int bestattrid=FindBestAttribute(example,attributes);  
  125.     now->attrid=bestattrid;  
  126.     attributes.erase(attributes.find(bestattrid));  
  127.       
  128.     /*将exmple根据最佳属性的不同属性值分成几个分支,每个分支有即一个子树*/  
  129.     vector< set<int> > child=vector< set<int> >(attrCount[bestattrid]);  
  130.     for(set<int>::iterator i=example.begin();i!=example.end();i++)  
  131.     {  
  132.         int id=str2int[DataTable[*i][bestattrid]];  
  133.         child[id].insert(*i);  
  134.     }  
  135.     for(int i=0;i<child.size();i++)  
  136.     {  
  137.         Node * ret=Id3_solution(child[i],attributes,now);  
  138.         now->childNode.push_back(ret);  
  139.     }  
  140.     return now;  
  141. }  
  142.   
  143. int main()  
  144. {  
  145.     Init();  
  146.     Node * Root=Id3_solution(S,Attributes,NULL);  
  147.     return 0;  

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在C++ STL中,红黑树的实现被封装在`std::map`和`std::set`这两个容器类中。这两个容器类都是基于红黑树实现的,它们提供了高效的查找、插入和删除操作,保证了元素的有序性。 STL中的红黑树实现与你提供的C++代码略有不同。STL中的红黑树使用节点颜色(红色或黑色)和节点指针(parent、left、right)来表示树的结构,而你提供的代码使用了模板和节点对象来实现。 在STL中,红黑树的插入和删除操作已经被封装在`std::map`和`std::set`中,使用起来非常简单。你只需要包含相应的头文件`<map>`或`<set>`,并使用`std::map`或`std::set`类来定义变量,就可以直接使用红黑树的功能了。 以下是使用STL中红黑树的简单示例: ```cpp #include <map> int main() { std::map<int, std::string> myMap; // 插入元素 myMap.insert(std::make_pair(1, "one")); myMap = "two"; // 查找元素 auto it = myMap.find(1); if (it != myMap.end()) { std::cout << it->second << std::endl; // 输出 "one" } // 删除元素 myMap.erase(2); return 0; } ``` 在上面的示例中,我们使用`std::map`来创建一个键-值对的红黑树。我们使用`insert`函数插入元素,使用`find`函数查找元素,使用`erase`函数删除元素。 总结一下,C++ STL中的红黑树实现被封装在`std::map`和`std::set`中,使用起来非常方便。你可以直接包含相应的头文件,并使用这些类来实现红黑树的功能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值