机器学习:决策树cart算法在分类与回归的应用(上)

机器学习cart算法的分类树

1、写在前面

决策树是通过一系列规则对数据进行分类的过程。它提供一种在什么条件下会得到什么值的类似规则的方法。决策树分为分类树和回归树两种,分类树对离散变量做决策树,回归树对连续变量做决策树,分类树是从ID3算法开始,改进成C4.5,随后又出现了cart算法,cart算法可以生成分类树,也可以生成回归树,每个决策树之所以不同主要是因为在最优标准属性的选择规则的不同,ID3的最优标准采取的是信息熵,改进成C4.5采取的是信息熵增益,而cart算法在构建分类树采取的标准是基尼指数,在构建回归时采取的标准为最小节点方差,因为ID3、C4.5在其他博客也已经详细介绍了,也通过python算法编程实现,这里就不再叙述,而对于cart算法,本人查阅很多资料以及博客,对其分类树的构建各有不同,而cart算法相对于前两个算法的优势是可以处理大批量的数据,故本人实例编程实现cart算法的分类树,回归树、模型树,以及将回归树与模型树进行了比较。

2、cart算法的分类树:

     分类树的最优标准属性的选择是采取基尼指数来判别的,基尼指数的计算过程本人就不再叙述,在相应的一些博客都介绍了,本人在此只想提出几个重点,这些都是其他一些博主在cart算法构建分类树时的不清不楚的地方:(1)、分类树构建结果肯定是一二叉树。(2)、分类树的剪枝分为前剪枝,后剪枝。前剪枝采取的是卡方检验来判断节点是否有再继续分裂的必要,以及后剪枝采取的是CCP(代价复杂度)剪枝方法,ccp剪枝方法是有个详细的步骤的。很多博主并没有根据这个步骤进行剪枝的。具体步骤在博文cart算法ccp剪枝具体过程(3)对不同分类树在处理连续属性与离散属性时,计算其基尼指数过程是一样的,在编程实现时,遍历属性下的数值时自动将遍历到的相同属性值分为一类,其它不同属性值分为另一类。以下本人贴出满足这三个重点的一位博主代码,当然这博主的代码本人仔细研究后,也发现了他的几处不足,在此本人就班门弄斧的对其修改了,不过我也是在他的代码基础上改进的,非常感谢他的公开,这才加强了我对cart算法分类树的理解。其中不足之处:
(1)、计算节点的表面误差率增益值α后采取的是c++模板priority_queue对每个节点α值进行排序,但是博主并没有实现自主定义的排序方法。

(2)、对生成树的后剪枝,应该利用测试集对其进行剪枝,利用测试集计算每一节点的表面误差率增益值α。

(3)、增加了一些函数,主要目的是利用测试集来计算表面误差率增益值α以及计算每个节点分类的正确率。

以下是本人改进后的分类树代码,在本人觉得不太容易理解的地方本人给出了注释,在此本人也给训练集与测试集链接训练集与测试集 密码 zf0q:

#include
     
     
      
      
#include
      
      
       
       
#include
       
       
        
        
#include
        
        
         
         
#include
         
          #include 
          
            #include 
           
             #include 
            
              #include 
             
               #include 
              
                #include 
               
                 using namespace std; //置信水平取0.95时的卡方表 const double CHI[18] = { 0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962 }; /*根据多维数组计算卡方值*/ template 
                
                  double cal_chi(Comparable **arr, int row, int col) { vector 
                 
                   rowsum(row); vector 
                  
                    colsum(col); Comparable totalsum = static_cast 
                   
                     (0);//强制将0转换为Comparable型 //cout<<"observation"< 
                    
                      right.first; } }; /* 下面这三个数据结构是来存在在哪种属性下的某一类的个数*/ typedef map 
                     
                       MAP_REST_COUNT; typedef map 
                      
                        MAP_ATTR_REST; typedef vector 
                       
                         VEC_STATI; const int ATTR_NUM = 6; //自变量的维度 vector 
                        
                          X(ATTR_NUM); int rest_number; //因变量的种类数,即类别数 vector 
                          
                          
                            > classes; //把类别、对应的记录数存放在一个数组中 int total_record_number; //总的记录数 vector 
                            
                            
                              > inputData; //原始输入数据 vector 
                              
                              
                                > testinputData; //测试输入数据 class node { public: node* parent; //父节点 node* leftchild; //左孩子节点 node* rightchild; //右孩子节点 string cond; //分枝条件 string decision; //在该节点上作出的类别判定 double precision; //判定的正确率 int record_number; //该节点上涵盖的记录个数 int size; //子树包含的叶子节点的数目 int index; //层次遍历树,给节点标上序号 double alpha; //表面误差率的增加量 node() { parent = NULL; leftchild = NULL; rightchild = NULL; precision = 0.0; record_number = 0; size = 1; index = 0; alpha = 1.0; } node(node* p) { parent = p; leftchild = NULL; rightchild = NULL; precision = 0.0; record_number = 0; size = 1; index = 0; alpha = 1.0; } node(node* p, string c, string d) :cond(c), decision(d) { parent = p; leftchild = NULL; rightchild = NULL; precision = 0.0; record_number = 0; size = 1; index = 0; alpha = 1.0; } void printInfo() { cout << "index:" << index << "\tdecisoin:" << decision << "\tprecision:" << precision << "\tcondition:" << cond << "\tsize:" << size; if (parent != NULL) cout << "\tparent index:" << parent->index; if (leftchild != NULL) cout << "\tleftchild:" << leftchild->index << "\trightchild:" << rightchild->index; cout << endl; } void printTree() { printInfo(); if (leftchild != NULL) leftchild->printTree(); if (rightchild != NULL) rightchild->printTree(); } }; /* 读取测试文件数据,采取的是c++字符串流的读取方式 得到结果:testinputData 数据源 */ int readtestInput(string filename) { ifstream ifs(filename.c_str()); if (!ifs) { cerr << "open inputfile failed!" << endl; return -1; } map 
                               
                                 catg; string line; getline(ifs, line); string item; istringstream strstm(line); strstm >> item; for (int i = 0; i 
                                
                                  > item; X[i] = item; } while (getline(ifs, line)) { vector 
                                 
                                   conts(ATTR_NUM + 2); istringstream strstm(line); //strstm.str(line); for (int i = 0; i 
                                  
                                    > item; conts[i] = item; if (i == conts.size() - 1) catg[item]++; } testinputData.push_back(conts); } total_record_number = testinputData.size(); ifs.close(); return 0; } /* 读取文件数据,采取的是c++字符串流的读取方式 得到结果:inputData 数据源 classes 分类标签以及个数(first:哺乳类,second:6) rest_number 分类的种类数 */ int readInput(string filename) { ifstream ifs(filename.c_str()); if (!ifs) { cerr << "open inputfile failed!" << endl; return -1; } map 
                                   
                                     catg; string line; getline(ifs, line); string item; istringstream strstm(line); strstm >> item; for (int i = 0; i 
                                    
                                      > item; X[i] = item; } while (getline(ifs, line)) { vector 
                                     
                                       conts(ATTR_NUM + 2); istringstream strstm(line); //strstm.str(line); for (int i = 0; i 
                                      
                                        > item; conts[i] = item; if (i == conts.size() - 1) catg[item]++; } inputData.push_back(conts); } total_record_number = inputData.size(); ifs.close(); map 
                                       
                                         ::const_iterator itr = catg.begin();//将catg归类结果放入classes中 while (itr != catg.end()) { classes.push_back(make_pair(itr->first, itr->second)); itr++; } rest_number = classes.size();//标签分为几类 return 0; } /*根据inputData作出一个统计stati,统计的是在哪种属性下的某类的个数。*/ void statistic(vector 
                                         
                                         
                                           > &inputData, VEC_STATI &stati) { for (int i = 1; i 
                                          
                                            second).find(rest); if (iter == (itr->second).end()) { (itr->second).insert(make_pair(rest, 1)); } else { iter->second += 1; } } } stati.push_back(attr_rest); } } /*依据某条件作出分枝时,inputData被分成两部分*/ void splitInput(vector 
                                            
                                            
                                              > &inputData, int fitIndex, string cond, vector 
                                              
                                              
                                                > &LinputData, vector 
                                                
                                                
                                                  > &RinputData) { for (int i = 0; i 
                                                 
                                                   > &inputData) { for (int i = 0; i < ATTR_NUM + 2; ++i) { for (int j = 0; j < inputData.size(); ++j) { cout << inputData[j][i] << "\t"; } }cout << endl; } void printStati(VEC_STATI &stati) { for (int i = 0; i 
                                                  
                                                    first; MAP_REST_COUNT::const_iterator iter = (itr->second).begin(); while (iter != (itr->second).end()) { cout << "\t" << iter->first << "\t" << iter->second; iter++; } itr++; cout << endl; } cout << endl; } } void split(node *root, vector 
                                                    
                                                    
                                                      > &inputData, vector 
                                                      
                                                      
                                                        > classes) { //root->printInfo(); root->record_number = inputData.size(); VEC_STATI stati; statistic(inputData, stati); //printStati(stati); //for(int i=0;i 
                                                       
                                                         > fitleftclasses;//左树的分类标签以及个数 vector 
                                                         
                                                         
                                                           > fitrightclasses;//右树的分类标签以及个数 int fitleftnumber;//左树记录数 int fitrightnumber; for (int i = 0; i 
                                                          
                                                            first; //判定的条件,即到达左孩子的条件,属性 //cout<<"cond 为"< 
                                                           
                                                             <<"时:"; vector 
                                                             
                                                             
                                                               > leftclasses(classes); //左孩子节点上类别、及对应的数目 vector 
                                                               
                                                               
                                                                 > rightclasses(classes); //右孩子节点上类别、及对应的数目 int leftnumber = 0; //左孩子节点上包含的类别数目 int rightnumber = 0; //右孩子节点上包含的类别数目 for (int j = 0; j 
                                                                
                                                                  second).find(rest);// if (iter2 == (itr->second).end()) { //没找到,则对应类别以及类别树就全部在右树 leftclasses[j].second = 0; rightnumber += rightclasses[j].second; } else { //找到,则右边树对应的种类以及个数就是总体的减去左边的种类数 leftclasses[j].second = iter2->second; leftnumber += leftclasses[j].second; rightclasses[j].second -= (iter2->second); rightnumber += rightclasses[j].second; } } /**if(leftnumber==0 || rightnumber==0){ cout<<"左右有一边为空"< 
                                                                 
                                                                   cond< 
                                                                  
                                                                    size)++; travel = travel->parent; } node *LChild = new node(root); //创建左右孩子 node *RChild = new node(root); root->leftchild = LChild; root->rightchild = RChild; int maxLcount = 0; int maxRcount = 0; string Ldicision, Rdicision; for (int i = 0; i 
                                                                   
                                                                     maxLcount) { maxLcount = fitleftclasses[i].second; Ldicision = fitleftclasses[i].first; } if (fitrightclasses[i].second>maxRcount) { maxRcount = fitrightclasses[i].second; Rdicision = fitrightclasses[i].first; } } LChild->decision = Ldicision; RChild->decision = Rdicision; //LChild->precision = 1.0*maxLcount / fitleftnumber; //RChild->precision = 1.0*maxRcount / fitrightnumber; /*递归对左右孩子进行分裂*/ vector 
                                                                     
                                                                     
                                                                       > LinputData, RinputData; splitInput(inputData, fitIndex, fitCond, LinputData, RinputData); //cout<<"左边inputData行数:"< 
                                                                      
                                                                        < 
                                                                       
                                                                         > &testinputData) { int i=0; int fitIndex; total_record_number = testinputData.size(); node *LChild= new node(root); node *RChild= new node(root); vector 
                                                                         
                                                                         
                                                                           > LinputData, RinputData; LChild =root->leftchild; RChild = root->rightchild; if (root->leftchild == NULL) return; string cond = root->cond;//分支条件是字符串:属性=属性下的分类,一下是对字符串的操作 string::size_type pos = cond.find("="); string pre = cond.substr(0, pos);//将字符串前0-pos的位置的子字符串赋予pre string post = cond.substr(pos + 1);//在此节点上的分支 for(int index=0;index 
                                                                          
                                                                            record_number = LinputData.size(); RChild->record_number = RinputData.size(); //printinputData(LinputData); //printinputData(RinputData); /*计算正确率*/ for (int j = 0; j < LinputData.size(); ++j) { string rest = LinputData[j][ATTR_NUM + 1];//左树这一行的标签 if (rest == LChild->decision) i++; } if (LChild->record_number == 0) LChild->precision = 0; else LChild->precision=1.0*i/LChild->record_number; i = 0; for (int j = 0; j < RinputData.size(); ++j) { string rest = RinputData[j][ATTR_NUM + 1];//右树这一行的标签 if (rest == RChild->decision) i++; } if (RChild->record_number == 0) RChild->precision=0; else RChild->precision = 1.0*i/RChild->record_number; if(LChild->leftchild!=NULL) pruneprecision(LChild,LinputData); if(RChild->leftchild!=NULL) pruneprecision(RChild, RinputData); } /*计算子树的误差代价*/ double calR2(node *root) { if (root->leftchild == NULL)//叶子结点是没有左右子树的 return (1 - root->precision)*root->record_number / total_record_number; else return calR2(root->leftchild) + calR2(root->rightchild); } /*层次遍历树,给节点标上序号*/ void index(node *root) { int i = 1; queue 
                                                                           
                                                                             que; que.push(root); while (!que.empty()) { node* n = que.front(); que.pop(); n->index = i++; if (n->leftchild != NULL) { que.push(n->leftchild); que.push(n->rightchild); } } } /*层次遍历树,给节点标上序号。同时计算alpha*/ void calalpha(node *root, priority_queue 
                                                                            
                                                                              , MyCompare> &pq) { int i = 1; queue 
                                                                             
                                                                               que; que.push(root); while (!que.empty()) { node* n = que.front(); que.pop(); n->index = i++; if (n->leftchild != NULL) { que.push(n->leftchild); que.push(n->rightchild); //计算表面误差率的增量 double r1 = (1 - n->precision)*n->record_number / total_record_number; //节点的误差代价 double r2 = calR2(n); n->alpha = (r1 - r2) / (n->size - 1); pq.push(MyTriple(n->alpha, n->size, n->index)); } } } /*剪枝*/ void prune(node *root, priority_queue 
                                                                              
                                                                                , MyCompare> &pq) { MyTriple triple = pq.top(); int i = triple.third; queue 
                                                                               
                                                                                 que; que.push(root); while (!que.empty()) { node* n = que.front(); que.pop(); if (n->index == i) { cout << "将要剪掉" << i << "的左右子树" << endl; n->leftchild = NULL; n->rightchild = NULL; int s = n->size - 1; node *trav = n; while (trav != NULL) { trav->size -= s; trav = trav->parent; } break; } else if (n->leftchild != NULL) { que.push(n->leftchild); que.push(n->rightchild); } } } void test(string filename, node *root,int labels) { ifstream ifs(filename.c_str()); if (!ifs) { cerr << "open inputfile failed!" << endl; return; } string line; getline(ifs, line); string item; istringstream strstm(line); //跳过第一行 map 
                                                                                
                                                                                  independent; //自变量,即分类的依据 while (getline(ifs, line)) { istringstream strstm(line); //strstm.str(line); strstm >> item; cout << item << "\t"; for (int i = 0; i 
                                                                                 
                                                                                   > item; independent[X[i]] = item; } node *trav = root; while (trav != NULL) { if (trav->leftchild == NULL) { if (labels >0) { cout << (trav->decision) << "\t置信度:" << (trav->precision) << endl; break; } else cout << (trav->decision) << endl; } string cond = trav->cond;//分支条件是字符串:属性=属性下的分类,一下是对字符串的操作 string::size_type pos = cond.find("="); string pre = cond.substr(0, pos);//将字符串前0-pos的位置的子字符串赋予pre string post = cond.substr(pos + 1); if (independent[pre] == post) trav = trav->leftchild; else trav = trav->rightchild; } } ifs.close(); } int main() { string inputFile = "watermelon.txt"; readInput(inputFile); VEC_STATI stati,teststati; //最原始的统计 statistic(inputData, stati); // for(int i=0;i 
                                                                                  
                                                                                    printTree(); cout << "剪枝前使用该决策树最多进行" << root->size - 1 << "次条件判断" << endl; string testFile = "testwatermelon.txt"; readtestInput(testFile); test(testFile, root,0); /*进行剪枝*/ pruneprecision(root,testinputData); //root->printTree(); priority_queue 
                                                                                   
                                                                                     , MyCompare> pq; calalpha(root,pq); /*//检验一个是不是表面误差增量最小的被剪掉了 while(!pq.empty()){ MyTriple triple=pq.top(); pq.pop(); cout< 
                                                                                    
                                                                                      <<"\t"< 
                                                                                     
                                                                                       <<"\t"< 
                                                                                      
                                                                                        < 
                                                                                       
                                                                                         size - 1 << "次条件判断" << endl; test(testFile, root,1); /*priority_queue 
                                                                                        
                                                                                          pq; calalpha(root, pq); root->printTree(); prune(root, pq); cout << "剪枝后使用该决策树最多进行" << root->size - 1 << "次条件判断" << endl; test(testFile, root);*/ system("pause"); return 0; } 
                                                                                         
                                                                                        
                                                                                       
                                                                                      
                                                                                     
                                                                                    
                                                                                   
                                                                                  
                                                                                 
                                                                                
                                                                               
                                                                              
                                                                             
                                                                            
                                                                           
                                                                          
                                                                         
                                                                        
                                                                       
                                                                      
                                                                     
                                                                    
                                                                   
                                                                  
                                                                 
                                                                
                                                               
                                                              
                                                             
                                                            
                                                           
                                                          
                                                         
                                                        
                                                       
                                                      
                                                     
                                                    
                                                   
                                                  
                                                 
                                                
                                               
                                              
                                             
                                            
                                           
                                          
                                         
                                        
                                       
                                      
                                     
                                    
                                   
                                  
                                 
                                
                               
                              
                             
                            
                           
                          
                         
                        
                       
                      
                     
                    
                   
                  
                 
                
               
              
             
            
          
        
        
       
       
      
      
     
     

最后贴一下代码运行结果图:(第一个是对watermelon数据源的分类树构建结果,以及利用测试集剪枝结果,足见在未剪枝时需进行4次条件判断,错误3个,而剪枝后分类判断

只进行1次,错误也为3个,若训练集数量大,测试集也多,既可以在降低条件判断步数,不降低分类的正确率)


以下是animal数据集的分类树构建与剪枝结果图(注意:要将代码中存放属性维数ATTR_NUM 改为8):




3、写在最后

在认真分析代码后,基本上完成了分类树准确构建(二叉树)以及前后剪枝(卡方值、cpp剪枝),当然,从代码运行结果可以看出,剪枝可以降低判断条件步数,但是有时会降低结果的正确率,这就取决于你偏向于分类速度还是正确度。当然你也可以在加一个判断条件,即在剪枝后的准确率降低了话,你也不采取剪枝。因为这不是重点,代码重点实现cart分类树的构建以及ccp的计算以及剪枝的实现。希望本博客能寄予帮助,下一篇介绍cart算法的回归树以及与模型树的对比,一起学习,一起进行。
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值