决策树的训练过程

决策树--------以分类有毒和无毒蘑菇为例

 

这里的原始数据是agaricus-lepiota.data文件。每一行代表一个样本,每一列代表一个特征,这里一共8024个样本,每个样本有22个特征。其中第一列是样本标签,标示有毒(p)和无毒(e)。除此之外,还有22列,代表22个特征。有些特征缺失,用?表示。下面是一部分数据的例子。

p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u

e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g

e,b,s,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,n,m

p,x,y,w,t,p,f,c,n,n,e,e,s,s,w,w,p,w,o,p,k,s,u

e,x,s,g,f,n,f,w,b,k,t,e,s,s,w,w,p,w,o,e,n,a,g

e,x,y,y,t,a,f,c,b,n,e,c,s,s,w,w,p,w,o,p,k,n,g

e,b,s,w,t,a,f,c,b,g,e,c,s,s,w,w,p,w,o,p,k,n,m

e,b,y,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,s,m

p,x,y,w,t,p,f,c,n,p,e,e,s,s,w,w,p,w,o,p,k,v,g

 

 

将这些数据创建成特征向量数据data(8024*22),标签数据responses(8024*1),丢失特征数据missing(8024*22)。用下面接口实现训练的过程。

dtree = new CvDTree;

tree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,

                  CvDTreeParams( 8, // max depth

                                 10, // min sample count

                                 0, // regression accuracy: N/A here

                                 true, // compute surrogate split, as we have missing data

                                 15, // max number of categories (use sub-optimal algorithm for larger numbers)

                                 10, // the number of cross-validation folds

                                 true, // use 1SE rule => smaller tree

                                 true, // throw away the pruned tree branches

                                 priors // the array of priors, the bigger p_weight, the more attention

                                        // to the poisonous mushrooms

                                        // (a mushroom will be judjed to be poisonous with bigger chance)

                                 ));

 

 

 

 

 

 

下面详细讲解一下训练的过程。

(1)train核心是使用do_train函数

bool CvDTree::do_train( const CvMat* _subsample_idx )

 

(2)do_train函数核心是使用try_split_node函数

void CvDTree::try_split_node( CvDTreeNode* node )

{

    CvDTreeSplit* best_split = 0;

    int i, n = node->sample_count, vi;

    bool can_split = true;

    double quality_scale;

 

calc_node_value( node );//修改CvDTree中的成员变量CvDTreeTrainData* data

//统计每一类的样本个数,并将其写入data的counts成员变量中,counts为矩阵

 

//节点的样本个数小于参数的最小样本数或者节点的深度大于最大深度,此函数返回

//否则,不断进行分裂

    if( node->sample_count <= data->params.min_sample_count ||

        node->depth >= data->params.max_depth )

        can_split = false;

 

    if( can_split && data->is_classifier )

    {

        // check if we have a "pure" node,

        // we assume that cls_count is filled by calc_node_value()

//如nz=1,则该节点纯净度为0,即不再需要分裂

int* cls_count = data->counts->data.i;

        int nz = 0, m = data->get_num_classes();

        for( i = 0; i < m; i++ )

            nz += cls_count[i] != 0;

        if( nz == 1 ) // there is only one class

            can_split = false;

    }

    else if( can_split )

    {

        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )

            can_split = false;

    }

 

    if( can_split )

    {

        best_split = find_best_split(node);//find_best_split是核心

        // TODO: check the split quality ...

        node->split = best_split;

    }

    if( !can_split || !best_split )

    {

        data->free_node_data(node);

        return;

    }

 

    quality_scale = calc_node_dir( node );

    if( data->params.use_surrogates )

    {

        // find all the surrogate splits

        // and sort them by their similarity to the primary one

       。。。。。。。缺失特征的处理

    }

    split_node_data( node );//根据最优分裂的特征进行分裂

    try_split_node( node->left );//左分支继续递归

    try_split_node( node->right );//右分支继续递归

}

 

      try_split_node函数本质就是不断递归的过程,不断对新分裂的左右节点分别继续分裂,直至节点纯度为0,或者节点的样本数达到最小标准或者节点的深度达到最大深度的条件。

针对每个节点分裂为两个子节点时,总是要寻找最优的分裂组合,即find_best_split实现这个过程。以蘑菇为例,该函数找出当前节点下,选取22个特征的最优组合,每个特征选出其取值的最优组合。

 

(3)try_split_node函数重点是使用find_best_split函数,下面着重介绍一下find_best_split函数,find_best_split函数通过下面这句命令,实现此过程。

cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);

 

实际上,parallel_reduce是调用body函数,其中重载了操作符()。即实际上是执行下面的函数。

void DTreeBestSplitFinder::operator()(const BlockedRange& range)

{

int vi, vi1 = range.begin(), vi2 = range.end();//上面的data->var_count为特

//征的总个数,对应于蘑菇的例子,特征总数为22,所以这里的vi1= 0,vi2 =22

int n = node->sample_count;//n为node中样本的个数,蘑菇例子中,如果是根节点,

//n为8124,下一级左边节点为4364

    CvDTreeTrainData* data = tree->get_data();

    AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));

 

    //蘑菇的22个特征依次循环,找到最优的分裂特征

    for( vi = vi1; vi < vi2; vi++ )

    {

        CvDTreeSplit *res;

        int ci = data->get_var_type(vi);

        if( node->get_num_valid(vi) <= 1 )

            continue;

 

        if( data->is_classifier )

        {

            if( ci >= 0 )

                res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );

           。。。。。//分类回归调用不同的接口,这里先只看分类

        }

 

// bestSplit为DTreeBestSplitFinder的成员变量,保存最优的分裂特征

        if( res && bestSplit->quality < split->quality )

                memcpy( (CvDTreeSplit*)bestSplit, (CvDTreeSplit*)split, splitSize );

    }

}

 

4重点是find_split_cat_class函数,该函数能够找到指定的某个特征vi对应的最优分裂的值域组合。比如指定的特征vi对应着颜色特征,该特征的值域有红,黄,蓝三个,find_split_cat_class就会找出最优的值域组合,使得分裂后不纯度最小,如(红,蓝)为左边分支,黄为右边分支。

 

CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, 

float init_quality,CvDTreeSplit* _split, uchar* _ext_buf )

{

    。。。。。

    int m = data->get_num_classes();//m 为分类的种类,如有毒和无毒两种蘑菇,m=2

    int _mi = data->cat_count->data.i[ci], mi = _mi;//mi为特征vi对应值的种类个数,//如蘑菇的颜色特征有红黄蓝绿紫黑六种,则mi=6

   。。。。。。

    //以下均以2类问题为例

    int* lc = (int*)base_buf;//lc为1*2的数组,存储左边分支的正负样本数目

    int* rc = lc + m; //rc为1*2的数组,存储右边分支的正负样本数目

int* _cjk = rc + m*2, *cjk = _cjk; //cjk为6*2的数组,存储所有样本按照颜色特

//征取不同值时,相应的正负样本数目。

 

// c_weights为6*1的数组,存储所有样本按照颜色特征取不同值时,相应的样本权重

//个数。即为cjk数组的每一行求和,即某个特征的某个特征取值对应的正负样本个数

//分别与其权重乘积之后求和

double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double)); 

 

// labels为所有样本颜色特征的对应值,为8124的数组,数组的每个值对应一个样本

//相应的颜色取值,如0~5中的一个

// responses为所有样本对应的正负标签,为8124的数组,数组的每个值对应一个样本

//为正样本,还是负样本

    int* labels_buf = (int*)ext_buf;

    const int* labels = data->get_cat_var_data(node, vi, labels_buf);

    int* responses_buf = labels_buf + n;

    const int* responses = data->get_class_labels(node, responses_buf);

 

   。。。。。。。。。。。

    double L = 0, R = 0;

    double best_val = init_quality;

    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    const double* priors = data->priors_mult->data.db;

 

    // init array of counters:

    // c_{jk} - number of samples that have vi-th input variable = j and response //= k.

//cjk为6*2的数组,存储所有样本按照颜色特征取不同值时,相应的正负样本数目。

//即特征vi取值为j时,对应的k样本的数目(k为正或者负两种取值)

 

    for( j = -1; j < mi; j++ )

        for( k = 0; k < m; k++ )

            cjk[j*m + k] = 0;

 

    for( i = 0; i < n; i++ )

    {

       j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];

       k = responses[i];

       cjk[j*m + k]++;

    }

 

    if( m > 2 )

    {

       。。。。。。。。聚类

    }

    else

{

        //将int_ptr按照样本数目大小排序,数目大的排在上面,以便后面循环找特征值//,按照int_ptr的顺序寻找最优分裂,int_ptr的某一行对应着cjk的某一行

        assert( m == 2 );

        int_ptr = (int**)(c_weights + _mi);

        for( j = 0; j < mi; j++ )

int_ptr[j] = cjk + j*2 + 1;// int_ptr为cjk矩阵中的行指针+1,即指向

//负样本数目,假设cjk第一列为正样本数目,第二列为负样本数目

        icvSortIntPtr( int_ptr, mi, 0 );

        subset_i = 0;

        subset_n = mi;//寻找最优值域需要的尝试次数,次数为值域的个数,即为6

    }

 

//lc为1*2的数组,存储右边分支的正负样本数目,这里初始化为0,

//后面每取出颜色特//征的一个某取值时,lc将加上其正负样本数目,

//即lc为选用颜色特征的某个值域时,左边分支的正负样本数目,

//左边分支的样本的颜色特征是属于这个值域的。

 

//rc为1*2的数组,存储右边分支的正负样本数目,这里分别初始化为所有的正样本数目//和负样本数目,后面每取出颜色特征的一个某取值时,rc将减去其正负样本数目

//即rc为选用颜色特征的某个取值时,右边分支的正负样本数目

//右边分支的样本的颜色特征是不属于这个值域的。

    for( k = 0; k < m; k++ )

    {

        int sum = 0;

        for( j = 0; j < mi; j++ )

            sum += cjk[j*m + k];

        rc[k] = sum;

        lc[k] = 0;

    }

 

 

// c_weights为6*1的数组,存储所有样本按照颜色特征取不同值时,相应的样本权重//个数。即为cjk数组的每一行求和,即某个特征的某个特征取值对应的正负样本个数/ //分别与其权重乘积之后求和

 

// L存储左边分支的正负样本个数与其权重乘积之后的累积和,这里初始化为0

//后面每取出颜色特征的一个某取值时,L将加上其对应的样本权重数目,即加上//c_weights的对应值

//即L为选用颜色特征的某个取值时,左边分支的样本权重数目

//左边分支的样本的颜色特征是属于这个值域的。

 

// R存储右边分支的正负样本个数与其权重乘积之后的累积和,这里初始化为

//c_weights的累积和,即为所有样本的个数和其权重乘积之和

//后面每取出颜色特征的一个某取值时,R将减去其对应的样本权重数目,即减去//c_weights的对应值

//即R为选用颜色特征的某个取值时,右边分支的样本权重数目

//右边分支的样本的颜色特征是不属于这个值域的。

    for( j = 0; j < mi; j++ )

    {

        double sum = 0;

        for( k = 0; k < m; k++ )

            sum += cjk[j*m + k]*priors[k];

        c_weights[j] = sum;

        R += c_weights[j];

    }

 

    //6次尝试,寻找最优值域组合,6为颜色特征的取值种类数。

//理论上讲,应该进行2的6次方减2次尝试,但是2分类问题,可以采用这种方法简化

    for( ; subset_i < subset_n; subset_i++ )

    {

        double weight;

        int* crow;

        double lsum2 = 0, rsum2 = 0;

 

        if( m == 2 )

            idx = (int)(int_ptr[subset_i] - cjk)/2; //idx为cjk的某一行,这里为0~5

        else

        {

            。。。。。。

        }

 

        crow = cjk + idx*m;  //crow为cjk的行指针,cjk的某一行对应着指定特征的

//某一个取值,如0行代表红色的正负样本数,1行代表蓝色的正负样本数,等等

        weight = c_weights[idx];//特征的某个取值对应的正负样本数的权重累计和。

        if( weight < FLT_EPSILON )

            continue;

 

        //

        if( !subtract )

        {

            for( k = 0; k < m; k++ )

            {

int t = crow[k]; //crow为指定特征的某一个取值对应的

//正样本数,负样本数

int lval = lc[k] + t;// lval为指定特征的某一个取值之前的所有取

//值(包含该取值)对应的正样本个数,负样本个数

int rval = rc[k] - t; // rval为指定特征的某一个取值之前的所有取//值(包含该取值)对应的正样本个数,负样本个数的补集,即不满足指

//定的特征取值的正样本个数,负样本个数

double p = priors[k], p2 = p*p;//p为正样本的权重(1/11)和负样

//本的权重(10/11)

sum2 += p2*lval*lval;// lsum2为指定特征的某一个取值之前的所有

//取值(包含该取值)对应的正样本个数的权重累积和与负样本个数权重

//累积和之和

rsum2 += p2*rval*rval; // rsum2为指定特征的某一个取值之前的所有//取值(包含该取值)对应的正样本个数的权重累积和的补集与负样本个

//数权重累积和的补集只和

 

                lc[k] = lval; rc[k] = rval;

            }

            L += weight;

            R -= weight;

        }

        else

        {

           。。。。。

        }

 

        if( L > FLT_EPSILON && R > FLT_EPSILON )

        {

            //找出纯度最小,即val最大的分类

            double val = (lsum2*R + rsum2*L)/((double)L*R);

            if( best_val < val )

            {

                best_val = val;

                best_subset = subset_i;

            }

        }

    }

 

    CvDTreeSplit* split = 0;

    if( best_subset >= 0 )

    {

        split = _split ? _split : data->new_split_cat( 0, -1.0f );

        split->var_idx = vi;// split->var_idx用来标识该分裂使用哪个特征

        split->quality = (float)best_val;

        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));

        //每个特征取值对应subset的一位,其对应关系和cjk中的取值排列顺序一致

        //每个int为32位,每个subset最多对应32个取值种类数

if( m == 2 )

        {

            //设置分裂的特征值域为best_subset之前对应的所有取值

            for( i = 0; i <= best_subset; i++ )

            {

                //idx是cjk排序后的特征值对应的顺序

idx = (int)(int_ptr[i] - cjk) >> 1;      

 

//1左移idx位( 当idx<32),即为2的idx次方

//即split->subset用每一位表示是否使用颜色特征的这个值,1表示

//使用,0表示未使用

//如果cjk的第一行为红色,第二行为蓝色,第三行为绿色,。。。。。

//则split->subset第一位为红色,第二位为蓝色,第三位为绿色。。。。

//这里位的顺序是按照从右往左计算的

split->subset[idx >> 5] |= 1 << (idx & 31);

            }

        }

        else

        {

            for( i = 0; i < _mi; i++ )

            {

                idx = cluster_labels ? cluster_labels[i] : i;

                if( best_subset & (1 << idx) )

                    split->subset[i >> 5] |= 1 << (i & 31);

            }

        }

    }

    return split;

}

 

(5)需要指出上面计算分裂不纯度是用下面的程序实现的:

double val = (lsum2*R + rsum2*L)/((double)L*R);

然后找出是val最大时对应的取值。

 

理论上讲,不纯度是使用下面的公式:

其中pi为某类样本出现的概率,二类问题分别对应着正样本的概率和负样本的概率。

最优分裂是不纯度最小时对应得分裂。

 

上面之所以求val最大值,即求的是不纯度公式中的求和这一项。

上面的公式是某个节点的不纯度,当分裂后,会变成两个节点,此时该特征的某值域对应的分裂的不纯度中求和项为:

 

 

由于某次分裂中,(L+R)是固定值,对应着分裂前所有正负样本的权重累积和。当寻找某个特征对应的最优分裂值域时,使用的样本是一样的,因此在某个节点的最优分裂的比较中,L+R是一个定值。因此计算的val如下:

double val = (lsum2*R + rsum2*L)/((double)L*R);

 


  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值