CART分类算法

统计学习方法是一本好书,可惜例子太少。找到一个好的CART算法的例子

谢谢原文作者了

http://www.cnblogs.com/zhangchaoyang  作者:Orisun

分类回归树(CART,Classification And Regression Tree)也属于一种决策树,上回文我们介绍了基于ID3算法的决策树。作为上篇,这里只介绍CART是怎样用于分类的。

分类回归树是一棵二叉树,且每个非叶子节点都有两个孩子,所以对于第一棵子树其叶子节点数比非叶子节点数多1。

表1

名称体温表面覆盖胎生产蛋能飞水生有腿冬眠类标记
恒温毛发哺乳类
巨蟒冷血鳞片爬行类
鲑鱼冷血鳞片鱼类
恒温毛发哺乳类
冷血有时两栖类
巨蜥冷血鳞片爬行类
蝙蝠恒温毛发哺乳类
恒温哺乳类
豹纹鲨冷血鳞片鱼类
海龟冷血鳞片有时爬行类
豪猪恒温刚毛哺乳类
冷血鳞片鱼类
蝾螈冷血有时两栖类

上例是属性有8个,每个属性又有多少离散的值可取。在决策树的每一个节点上我们可以按任一个属性的任一个值进行划分。比如最开始我们按:

1)表面覆盖为毛发和非毛发

2)表面覆盖为鳞片和非鳞片

3)体温为恒温和非恒温

等等产生当前节点的左右两个孩子。按哪种划分最好呢?有3个标准可以用来衡量划分的好坏:GINI指数、双化指数、有序双化指数。下面我们只讲GINI指数。

GINI指数

总体内包含的类别越杂乱,GINI指数就越大(跟熵的概念很相似)。比如体温为恒温时包含哺乳类5个、鸟类2个,则:

体温为非恒温时包含爬行类3个、鱼类3个、两栖类2个,则

所以如果按照“体温为恒温和非恒温”进行划分的话,我们得到GINI的增益(类比信息增益):

最好的划分就是使得GINI_Gain最小的划分。

终止条件

一个节点产生左右孩子后,递归地对左右孩子进行划分即可产生分类回归树。这里的终止条件是什么?什么时候节点就可以停止分裂了?直观的情况,当节点包含的数据记录都属于同一个类别时就可以终止分裂了。这只是一个特例,更一般的情况我们计算χ2值来判断分类条件和类别的相关程度,当χ2很小时说明分类条件和类别是独立的,即按照该分类条件进行分类是没有道理的,此时节点停止分裂。注意这里的“分类条件”是指按照GINI_Gain最小原则得到的“分类条件”。

假如在构造分类回归树的第一步我们得到的“分类条件”是:体温为恒温和非恒温。此时:

 哺乳类爬行类鱼类鸟类两栖类
恒温50020
非恒温03302

我在《独立性检验》中讲述了χ2的计算方法。当选定置信水平后查表可得“体温”与动物类别是否相互独立。

还有一种方式就是,如果某一分支覆盖的样本的个数如果小于一个阈值,那么也可产生叶子节点,从而终止Tree-Growth。

如何确定叶子节点的类?

前面提到Tree-Growth终止的方式有2种,对于第一种方式,叶子节点覆盖的样本都属于同一类,那么这种情况下叶子节点的类自然不必多言。对于第二种方式,叶子节点覆盖的样本未必属于同一类,直接一点的方法就是,该叶子节点所覆盖的样本哪个类占大多数,那么该叶子节点的类别就是那个占大多数的类。


剪枝

当分类回归树划分得太细时,会对噪声数据产生过拟合作用。因此我们要通过剪枝来解决。剪枝又分为前剪枝和后剪枝:前剪枝是指在构造树的过程中就知道哪些节点可以剪掉,于是干脆不对这些节点进行分裂,在N皇后问题和背包问题中用的都是前剪枝,上面的χ2方法也可以认为是一种前剪枝;后剪枝是指构造出完整的决策树之后再来考查哪些子树可以剪掉。

在分类回归树中可以使用的后剪枝方法有多种,比如:代价复杂性剪枝、最小误差剪枝、悲观误差剪枝等等。这里我们只介绍代价复杂性剪枝法。

对于分类回归树中的每一个非叶子节点计算它的表面误差率增益值α。

是子树中包含的叶子节点个数;

是节点t的误差代价,如果该节点被剪枝;

r(t)是节点t的误差率;

p(t)是节点t上的数据占所有数据的比例。

是子树Tt的误差代价,如果该节点不被剪枝。它等于子树Tt上所有叶子节点的误差代价之和。

比如有个非叶子节点t4如图所示:

已知所有的数据总共有60条,则节点t4的节点误差代价为:

子树误差代价为:

以t4为根节点的子树上叶子节点有3个,最终:

找到α值最小的非叶子节点,令其左右孩子为NULL。当多个非叶子节点的α值同时达到最小时,取最大的进行剪枝。

#include<iostream>
#include<fstream>
#include<sstream>
#include<string>
#include<map>
#include<list>
#include<set>
#include<queue>
#include<utility>
#include<vector>
#include<cmath>
 
using namespace std;
 
//置信水平取0.95时的卡方表
const double CHI[18]={0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962};
/*根据多维数组计算卡方值*/
template<typename Comparable>
double cal_chi(Comparable **arr,int row,int col){
    vector<Comparable> rowsum(row);
    vector<Comparable> colsum(col);
    Comparable totalsum=static_cast<Comparable>(0);
    //cout<<"observation"<<endl;
    for(int i=0;i<row;++i){
        for(int j=0;j<col;++j){
            //cout<<arr[i][j]<<"\t";
            totalsum+=arr[i][j];
            rowsum[i]+=arr[i][j];
            colsum[j]+=arr[i][j];
        }
        //cout<<endl;
    }
    double rect=0.0;
    //cout<<"exception"<<endl;
    for(int i=0;i<row;++i){
        for(int j=0;j<col;++j){
            double excep=1.0*rowsum[i]*colsum[j]/totalsum;
            //cout<<excep<<"\t";
            if(excep!=0)
                rect+=pow(arr[i][j]-excep,2.0)/excep;
        }
        //cout<<endl;
    }
    return rect;
}
 
class MyTriple{
public:
    double first;
    int second;
    int third;
    MyTriple(){
        first=0.0;
        second=0;
        third=0;
    }
    MyTriple(double f,int s,int t):first(f),second(s),third(t){}
    bool operator< (const MyTriple &obj) const{
        int cmp=this->first-obj.first;
        if(cmp>0)
            return false;
        else if(cmp<0)
            return true;
        else{
            cmp=obj.second-this->second;
            if(cmp<0)
                return true;
            else
                return false;
        }
    }
};
 
typedef map<string,int> MAP_REST_COUNT;
typedef map<string,MAP_REST_COUNT> MAP_ATTR_REST;
typedef vector<MAP_ATTR_REST> VEC_STATI;
 
const int ATTR_NUM=8;       //自变量的维度
vector<string> X(ATTR_NUM);
int rest_number;        //因变量的种类数,即类别数
vector<pair<string,int> > classes;      //把类别、对应的记录数存放在一个数组中
int total_record_number;        //总的记录数
vector<vector<string> > inputData;      //原始输入数据
 
class node{
public:
    node* parent;       //父节点
    node* leftchild;        //左孩子节点
    node* rightchild;       //右孩子节点
    string cond;        //分枝条件
    string decision;        //在该节点上作出的类别判定
    double precision;       //判定的正确率
    int record_number;      //该节点上涵盖的记录个数
    int size;       //子树包含的叶子节点的数目
    int index;      //层次遍历树,给节点标上序号
    double alpha;   //表面误差率的增加量
    node(){
        parent=NULL;
        leftchild=NULL;
        rightchild=NULL;
        precision=0.0;
        record_number=0;
        size=1;
        index=0;
        alpha=1.0;
    }
    node(node* p){
        parent=p;
        leftchild=NULL;
        rightchild=NULL;
        precision=0.0;
        record_number=0;
        size=1;
        index=0;
        alpha=1.0;
    }
    node(node* p,string c,string d):cond(c),decision(d){
        parent=p;
        leftchild=NULL;
        rightchild=NULL;
        precision=0.0;
        record_number=0;
        size=1;
        index=0;
        alpha=1.0;
    }
    void printInfo(){
        cout<<"index:"<<index<<"\tdecisoin:"<<decision<<"\tprecision:"<<precision<<"\tcondition:"<<cond<<"\tsize:"<<size;
        if(parent!=NULL)
            cout<<"\tparent index:"<<parent->index;
        if(leftchild!=NULL)
            cout<<"\tleftchild:"<<leftchild->index<<"\trightchild:"<<rightchild->index;
        cout<<endl;
    }
    void printTree(){
        printInfo();
        if(leftchild!=NULL)
            leftchild->printTree();
        if(rightchild!=NULL)
            rightchild->printTree();
    }
};
 
int readInput(string filename){
    ifstream ifs(filename.c_str());
    if(!ifs){
        cerr<<"open inputfile failed!"<<endl;
        return -1;
    }
    map<string,int> catg;
    string line;
    getline(ifs,line);
    string item;
    istringstream strstm(line);
    strstm>>item;
    for(int i=0;i<X.size();++i){
        strstm>>item;
        X[i]=item;
    }
    while(getline(ifs,line)){
        vector<string> conts(ATTR_NUM+2);
        istringstream strstm(line);
        //strstm.str(line);
        for(int i=0;i<conts.size();++i){
            strstm>>item;
            conts[i]=item;
            if(i==conts.size()-1)
                catg[item]++;
        }
        inputData.push_back(conts);
    }
    total_record_number=inputData.size();
    ifs.close();
    map<string,int>::const_iterator itr=catg.begin();
    while(itr!=catg.end()){
        classes.push_back(make_pair(itr->first,itr->second));
        itr++;
    }
    rest_number=classes.size();
    return 0;
}
 
/*根据inputData作出一个统计stati*/
void statistic(vector<vector<string> > &inputData,VEC_STATI &stati){
    for(int i=1;i<ATTR_NUM+1;++i){
        MAP_ATTR_REST attr_rest;
        for(int j=0;j<inputData.size();++j){
            string attr_value=inputData[j][i];
            string rest=inputData[j][ATTR_NUM+1];
            MAP_ATTR_REST::iterator itr=attr_rest.find(attr_value);
            if(itr==attr_rest.end()){
                MAP_REST_COUNT rest_count;
                rest_count[rest]=1;
                attr_rest[attr_value]=rest_count;
            }
            else{
                MAP_REST_COUNT::iterator iter=(itr->second).find(rest);
                if(iter==(itr->second).end()){
                    (itr->second).insert(make_pair(rest,1));
                }
                else{
                    iter->second+=1;
                }
            }
        }
        stati.push_back(attr_rest);
    }
}
 
/*依据某条件作出分枝时,inputData被分成两部分*/
void splitInput(vector<vector<string> > &inputData,int fitIndex,string cond,vector<vector<string> > &LinputData,vector<vector<string> > &RinputData){
    for(int i=0;i<inputData.size();++i){
        if(inputData[i][fitIndex+1]==cond)
            LinputData.push_back(inputData[i]);
        else
            RinputData.push_back(inputData[i]);
    }
}
 
void printStati(VEC_STATI &stati){
    for(int i=0;i<stati.size();i++){
        MAP_ATTR_REST::const_iterator itr=stati[i].begin();
        while(itr!=stati[i].end()){
            cout<<itr->first;
            MAP_REST_COUNT::const_iterator iter=(itr->second).begin();
            while(iter!=(itr->second).end()){
                cout<<"\t"<<iter->first<<"\t"<<iter->second;
                iter++;
            }
            itr++;
            cout<<endl;
        }
        cout<<endl;
    }
}
 
void split(node *root,vector<vector<string> > &inputData,vector<pair<string,int> > classes){
    //root->printInfo();
    root->record_number=inputData.size();
    VEC_STATI stati;
    statistic(inputData,stati);
    //printStati(stati);
    //for(int i=0;i<rest_number;i++)
    //  cout<<classes[i].first<<"\t"<<classes[i].second<<"\t";
    //cout<<endl;
    /*找到最大化GINI指标的划分*/
    double minGain=1.0;     //最小的GINI增益
    int fitIndex=-1;
    string fitCond;
    vector<pair<string,int> > fitleftclasses;
    vector<pair<string,int> > fitrightclasses;
    int fitleftnumber;
    int fitrightnumber;
    for(int i=0;i<stati.size();++i){     //扫描每一个自变量
        MAP_ATTR_REST::const_iterator itr=stati[i].begin();
        while(itr!=stati[i].end()){         //扫描自变量上的每一个取值
            string condition=itr->first;     //判定的条件,即到达左孩子的条件
            //cout<<"cond 为"<<X[i]+condition<<"时:";
            vector<pair<string,int> > leftclasses(classes);     //左孩子节点上类别、及对应的数目
            vector<pair<string,int> > rightclasses(classes);    //右孩子节点上类别、及对应的数目
            int leftnumber=0;       //左孩子节点上包含的类别数目
            int rightnumber=0;      //右孩子节点上包含的类别数目
            for(int j=0;j<leftclasses.size();++j){       //更新类别对应的数目
                string rest=leftclasses[j].first;
                MAP_REST_COUNT::const_iterator iter2;
                iter2=(itr->second).find(rest);
                if(iter2==(itr->second).end()){      //没找到
                    leftclasses[j].second=0;
                    rightnumber+=rightclasses[j].second;
                }
                else{       //找到
                    leftclasses[j].second=iter2->second;
                    leftnumber+=leftclasses[j].second;
                    rightclasses[j].second-=(iter2->second);
                    rightnumber+=rightclasses[j].second;
                }
            }
            /**if(leftnumber==0 || rightnumber==0){
                cout<<"左右有一边为空"<<endl;
                 
                for(int k=0;k<rest_number;k++)
                    cout<<leftclasses[k].first<<"\t"<<leftclasses[k].second<<"\t";
                cout<<endl;
                for(int k=0;k<rest_number;k++)
                    cout<<rightclasses[k].first<<"\t"<<rightclasses[k].second<<"\t";
                cout<<endl;
            }**/
            double gain1=1.0;           //计算GINI增益
            double gain2=1.0;
            if(leftnumber==0)
                gain1=0.0;
            else
                for(int j=0;j<leftclasses.size();++j)       
                    gain1-=pow(1.0*leftclasses[j].second/leftnumber,2.0);
            if(rightnumber==0)
                gain2=0.0;
            else
                for(int j=0;j<rightclasses.size();++j)
                    gain2-=pow(1.0*rightclasses[j].second/rightnumber,2.0);
            double gain=1.0*leftnumber/(leftnumber+rightnumber)*gain1+1.0*rightnumber/(leftnumber+rightnumber)*gain2;
            //cout<<"GINI增益:"<<gain<<endl;
            if(gain<minGain){
                //cout<<"GINI增益:"<<gain<<"\t"<<i<<"\t"<<condition<<endl;
                fitIndex=i;
                fitCond=condition;
                fitleftclasses=leftclasses;
                fitrightclasses=rightclasses;
                fitleftnumber=leftnumber;
                fitrightnumber=rightnumber;
                minGain=gain;
            }
            itr++;
        }
    }
 
    /*计算卡方值,看有没有必要进行分裂*/
    //cout<<"按"<<X[fitIndex]+fitCond<<"划分,计算卡方"<<endl;
    int **arr=new int*[2];
    for(int i=0;i<2;i++)
        arr[i]=new int[rest_number];
    for(int i=0;i<rest_number;i++){
        arr[0][i]=fitleftclasses[i].second;
        arr[1][i]=fitrightclasses[i].second;
    }
    double chi=cal_chi(arr,2,rest_number);
    //cout<<"chi="<<chi<<" CHI="<<CHI[rest_number-2]<<endl;
    if(chi<CHI[rest_number-2]){      //独立,没必要再分裂了
        delete []arr[0];    delete []arr[1];    delete []arr;
        return;             //不需要分裂函数就返回
    }
    delete []arr[0];    delete []arr[1];    delete []arr;
     
    /*分裂*/
    root->cond=X[fitIndex]+"="+fitCond;      //root的分枝条件
    //cout<<"分类条件:"<<root->cond<<endl;
    node *travel=root;      //root及其祖先节点的size都要加1
    while(travel!=NULL){
        (travel->size)++;
        travel=travel->parent;
    }
     
    node *LChild=new node(root);        //创建左右孩子
    node *RChild=new node(root);
    root->leftchild=LChild;
    root->rightchild=RChild;
    int maxLcount=0;
    int maxRcount=0;
    string Ldicision,Rdicision;
    for(int i=0;i<rest_number;++i){      //统计哪种类别出现的最多,从而作出类别判定
        if(fitleftclasses[i].second>maxLcount){
            maxLcount=fitleftclasses[i].second;
            Ldicision=fitleftclasses[i].first;
        }
        if(fitrightclasses[i].second>maxRcount){
            maxRcount=fitrightclasses[i].second;
            Rdicision=fitrightclasses[i].first;
        }
    }
    LChild->decision=Ldicision;
    RChild->decision=Rdicision;
    LChild->precision=1.0*maxLcount/fitleftnumber;
    RChild->precision=1.0*maxRcount/fitrightnumber;
     
    /*递归对左右孩子进行分裂*/
    vector<vector<string> > LinputData,RinputData;
    splitInput(inputData,fitIndex,fitCond,LinputData,RinputData);
    //cout<<"左边inputData行数:"<<LinputData.size()<<endl;
    //cout<<"右边inputData行数:"<<RinputData.size()<<endl;
    split(LChild,LinputData,fitleftclasses);
    split(RChild,RinputData,fitrightclasses);
}
 
/*计算子树的误差代价*/
double calR2(node *root){
    if(root->leftchild==NULL)
        return (1-root->precision)*root->record_number/total_record_number;
    else
        return calR2(root->leftchild)+calR2(root->rightchild);
}
 
/*层次遍历树,给节点标上序号。同时计算alpha*/
void index(node *root,priority_queue<MyTriple> &pq){
    int i=1;
    queue<node*> que;
    que.push(root);
    while(!que.empty()){
        node* n=que.front();
        que.pop();
        n->index=i++;
        if(n->leftchild!=NULL){
            que.push(n->leftchild);
            que.push(n->rightchild);
            //计算表面误差率的增量
            double r1=(1-n->precision)*n->record_number/total_record_number;      //节点的误差代价
            double r2=calR2(n);
            n->alpha=(r1-r2)/(n->size-1);
            pq.push(MyTriple(n->alpha,n->size,n->index));
        }
    }
}
 
/*剪枝*/
void prune(node *root,priority_queue<MyTriple> &pq){
    MyTriple triple=pq.top();
    int i=triple.third;
    queue<node*> que;
    que.push(root);
    while(!que.empty()){
        node* n=que.front();
        que.pop();
        if(n->index==i){
            cout<<"将要剪掉"<<i<<"的左右子树"<<endl;
            n->leftchild=NULL;
            n->rightchild=NULL;
            int s=n->size-1;
            node *trav=n;
            while(trav!=NULL){
                trav->size-=s;
                trav=trav->parent;
            }
            break;
        }
        else if(n->leftchild!=NULL){
            que.push(n->leftchild);
            que.push(n->rightchild);
        }
    }
}
 
void test(string filename,node *root){
    ifstream ifs(filename.c_str());
    if(!ifs){
        cerr<<"open inputfile failed!"<<endl;
        return;
    }
    string line;
    getline(ifs,line);
    string item;
    istringstream strstm(line);     //跳过第一行
    map<string,string> independent;       //自变量,即分类的依据
    while(getline(ifs,line)){
        istringstream strstm(line);
        //strstm.str(line);
        strstm>>item;
        cout<<item<<"\t";
        for(int i=0;i<ATTR_NUM;++i){
            strstm>>item;
            independent[X[i]]=item;
        }
        node *trav=root;
        while(trav!=NULL){
            if(trav->leftchild==NULL){
                cout<<(trav->decision)<<"\t置信度:"<<(trav->precision)<<endl;;
                break;
            }
            string cond=trav->cond;
            string::size_type pos=cond.find("=");
            string pre=cond.substr(0,pos);
            string post=cond.substr(pos+1);
            if(independent[pre]==post)
                trav=trav->leftchild;
            else
                trav=trav->rightchild;
        }
    }
    ifs.close();
}
 
int main(){
    string inputFile="animal";
    readInput(inputFile);
    VEC_STATI stati;        //最原始的统计
    statistic(inputData,stati);
     
//  for(int i=0;i<classes.size();++i)
//      cout<<classes[i].first<<"\t"<<classes[i].second<<"\t";
//  cout<<endl;
    node *root=new node();
    split(root,inputData,classes);      //分裂根节点
    priority_queue<MyTriple> pq;
    index(root,pq);
    root->printTree();
    cout<<"剪枝前使用该决策树最多进行"<<root->size-1<<"次条件判断"<<endl;
    /**
    //检验一个是不是表面误差增量最小的被剪掉了
    while(!pq.empty()){
        MyTriple triple=pq.top();
        pq.pop();
        cout<<triple.first<<"\t"<<triple.second<<"\t"<<triple.third<<endl;
    }
    **/
    test(inputFile,root);
     
    prune(root,pq);
    cout<<"剪枝后使用该决策树最多进行"<<root->size-1<<"次条件判断"<<endl;
    test(inputFile,root);
    return 0;
}
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个简单的使用 CART 决策树算法的 Java 代码示例: ```java import java.util.ArrayList; import java.util.HashMap; public class CART { public static void main(String[] args) { // 构造训练集数据 ArrayList<HashMap<String, String>> trainData = new ArrayList<>(); HashMap<String, String> data1 = new HashMap<>(); data1.put("age", "青年"); data1.put("income", "高"); data1.put("student", "否"); data1.put("credit_rating", "一般"); data1.put("class", "不放贷"); trainData.add(data1); HashMap<String, String> data2 = new HashMap<>(); data2.put("age", "青年"); data2.put("income", "高"); data2.put("student", "否"); data2.put("credit_rating", "好"); data2.put("class", "不放贷"); trainData.add(data2); HashMap<String, String> data3 = new HashMap<>(); data3.put("age", "中年"); data3.put("income", "高"); data3.put("student", "否"); data3.put("credit_rating", "好"); data3.put("class", "放贷"); trainData.add(data3); HashMap<String, String> data4 = new HashMap<>(); data4.put("age", "中年"); data4.put("income", "中等"); data4.put("student", "否"); data4.put("credit_rating", "好"); data4.put("class", "放贷"); trainData.add(data4); HashMap<String, String> data5 = new HashMap<>(); data5.put("age", "中年"); data5.put("income", "中等"); data5.put("student", "是"); data5.put("credit_rating", "一般"); data5.put("class", "放贷"); trainData.add(data5); HashMap<String, String> data6 = new HashMap<>(); data6.put("age", "老年"); data6.put("income", "中等"); data6.put("student", "是"); data6.put("credit_rating", "好"); data6.put("class", "放贷"); trainData.add(data6); HashMap<String, String> data7 = new HashMap<>(); data7.put("age", "老年"); data7.put("income", "低"); data7.put("student", "是"); data7.put("credit_rating", "好"); data7.put("class", "不放贷"); trainData.add(data7); HashMap<String, String> data8 = new HashMap<>(); data8.put("age", "老年"); data8.put("income", "低"); data8.put("student", "否"); data8.put("credit_rating", "一般"); data8.put("class", "不放贷"); trainData.add(data8); // 训练决策树模型 DecisionTreeModel model = train(trainData); System.out.println("决策树模型:" + model); // 预测新数据 HashMap<String, String> newData = new HashMap<>(); newData.put("age", "青年"); newData.put("income", "中等"); newData.put("student", "否"); newData.put("credit_rating", "一般"); String result = predict(newData, model); System.out.println("新数据预测结果:" + result); } /** * 训练决策树模型 * @param trainData 训练集数据 * @return 决策树模型 */ public static DecisionTreeModel train(ArrayList<HashMap<String, String>> trainData) { // 获取训练集属性列表 ArrayList<String> attributeList = new ArrayList<>(); for (String key : trainData.get(0).keySet()) { attributeList.add(key); } // 构建决策树模型 DecisionTreeModel model = new DecisionTreeModel(); buildDecisionTree(trainData, attributeList, model); return model; } /** * 构建决策树 * @param trainData 训练集数据 * @param attributeList 属性列表 * @param model 决策树模型 */ public static void buildDecisionTree(ArrayList<HashMap<String, String>> trainData, ArrayList<String> attributeList, DecisionTreeModel model) { // 如果训练集中所有数据属于同一类别,则将当前节点设置为叶子节点,并返回 boolean isSameClass = true; String firstClass = trainData.get(0).get("class"); for (HashMap<String, String> data : trainData) { if (!data.get("class").equals(firstClass)) { isSameClass = false; break; } } if (isSameClass) { model.isLeaf = true; model.className = firstClass; return; } // 如果属性列表为空,则将当前节点设置为叶子节点,并将其类别设置为训练集中最常见的类别 if (attributeList.isEmpty()) { model.isLeaf = true; model.className = getMostCommonClass(trainData); return; } // 选择最佳属性(即使得信息增益最大的属性) String bestAttribute = getBestAttribute(trainData, attributeList); model.attributeName = bestAttribute; // 根据最佳属性分裂训练集 ArrayList<ArrayList<HashMap<String, String>>> splitDataList = splitData(trainData, bestAttribute); // 递归构建子树 ArrayList<String> newAttributeList = new ArrayList<>(attributeList); newAttributeList.remove(bestAttribute); // 在属性列表中删除已经使用的属性 for (ArrayList<HashMap<String, String>> splitData : splitDataList) { DecisionTreeModel subModel = new DecisionTreeModel(); model.subModelList.add(subModel); buildDecisionTree(splitData, newAttributeList, subModel); } } /** * 预测新数据 * @param newData 新数据 * @param model 决策树模型 * @return 预测结果 */ public static String predict(HashMap<String, String> newData, DecisionTreeModel model) { // 如果当前节点是叶子节点,则返回其类别 if (model.isLeaf) { return model.className; } // 根据当前节点的属性进行分裂 String attributeValue = newData.get(model.attributeName); for (DecisionTreeModel subModel : model.subModelList) { if (subModel.attributeValue.equals(attributeValue)) { return predict(newData, subModel); } } // 如果当前节点没有与新数据匹配的子节点,则将其类别设置为训练集中最常见的类别 return getMostCommonClass(model.trainData); } /** * 获取训练集中最常见的类别 * @param trainData 训练集数据 * @return 最常见的类别 */ public static String getMostCommonClass(ArrayList<HashMap<String, String>> trainData) { HashMap<String, Integer> classCountMap = new HashMap<>(); for (HashMap<String, String> data : trainData) { String className = data.get("class"); if (classCountMap.containsKey(className)) { classCountMap.put(className, classCountMap.get(className) + 1); } else { classCountMap.put(className, 1); } } String mostCommonClass = ""; int maxCount = -1; for (String className : classCountMap.keySet()) { int count = classCountMap.get(className); if (count > maxCount) { mostCommonClass = className; maxCount = count; } } return mostCommonClass; } /** * 获取训练集中最佳属性 * @param trainData 训练集数据 * @param attributeList 属性列表 * @return 最佳属性 */ public static String getBestAttribute(ArrayList<HashMap<String, String>> trainData, ArrayList<String> attributeList) { String bestAttribute = ""; double maxInformationGain = -1; for (String attribute : attributeList) { double informationGain = calculateInformationGain(trainData, attribute); if (informationGain > maxInformationGain) { bestAttribute = attribute; maxInformationGain = informationGain; } } return bestAttribute; } /** * 根据指定属性值分裂训练集 * @param trainData 训练集数据 * @param attributeName 属性名称 * @return 分裂后的数据集列表 */ public static ArrayList<ArrayList<HashMap<String, String>>> splitData(ArrayList<HashMap<String, String>> trainData, String attributeName) { ArrayList<ArrayList<HashMap<String, String>>> splitDataList = new ArrayList<>(); for (HashMap<String, String> data : trainData) { String attributeValue = data.get(attributeName); boolean isSplitDataExist = false; for (ArrayList<HashMap<String, String>> splitData : splitDataList) { if (splitData.get(0).get(attributeName).equals(attributeValue)) { splitData.add(data); isSplitDataExist = true; break; } } if (!isSplitDataExist) { ArrayList<HashMap<String, String>> newSplitData = new ArrayList<>(); newSplitData.add(data); splitDataList.add(newSplitData); } } for (ArrayList<HashMap<String, String>> splitData : splitDataList) { if (splitData.size() > 0) { String attributeValue = splitData.get(0).get(attributeName); DecisionTreeModel subModel = new DecisionTreeModel(); subModel.attributeName = attributeName; subModel.attributeValue = attributeValue; subModel.trainData = splitData; } } return splitDataList; } /** * 计算指定属性的信息增益 * @param trainData 训练集数据 * @param attributeName 属性名称 * @return 信息增益 */ public static double calculateInformationGain(ArrayList<HashMap<String, String>> trainData, String attributeName) { // 计算训练集的熵 double entropy = calculateEntropy(trainData); // 计算分裂后的熵 double splitEntropy = 0; ArrayList<ArrayList<HashMap<String, String>>> splitDataList = splitData(trainData, attributeName); for (ArrayList<HashMap<String, String>> splitData : splitDataList) { double splitDataEntropy = calculateEntropy(splitData); double splitDataProbability = (double) splitData.size() / trainData.size(); splitEntropy += splitDataEntropy * splitDataProbability; } // 计算信息增益 double informationGain = entropy - splitEntropy; return informationGain; } /** * 计算数据集的熵 * @param dataList 数据集 * @return 熵 */ public static double calculateEntropy(ArrayList<HashMap<String, String>> dataList) { HashMap<String, Integer> classCountMap = new HashMap<>(); for (HashMap<String, String> data : dataList) { String className = data.get("class"); if (classCountMap.containsKey(className)) { classCountMap.put(className, classCountMap.get(className) + 1); } else { classCountMap.put(className, 1); } } double entropy = 0; for (String className : classCountMap.keySet()) { double probability = (double) classCountMap.get(className) / dataList.size(); entropy -= probability * Math.log(probability) / Math.log(2); } return entropy; } } /** * 决策树模型 */ class DecisionTreeModel { public boolean isLeaf; // 是否是叶子节点 public String attributeName; // 分裂属性名称 public String attributeValue; // 分裂属性值 public ArrayList<DecisionTreeModel> subModelList; // 子模型列表 public String className; // 类别名称 public ArrayList<HashMap<String, String>> trainData; // 训练集数据 public DecisionTreeModel() { this.isLeaf = false; this.attributeName = ""; this.attributeValue = ""; this.subModelList = new ArrayList<>(); this.className = ""; this.trainData = new ArrayList<>(); } public String toString() { StringBuilder sb = new StringBuilder(); if (isLeaf) { sb.append(className); } else { sb.append(attributeName + " -> "); for (DecisionTreeModel subModel : subModelList) { sb.append(subModel.attributeValue + ": " + subModel.toString() + "; "); } } return sb.toString(); } } ``` 这个示例代码实现了一个简化的 CART 决策树算法,并提供了训练和预测的方法。由于数据集比较小,所以没有进行剪枝等优化操作。在实际应用中,可以根据具体情况进行改进。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值