OpenCV(4)ML库->Decision Tree决策树

ID3, C4.5, C5.0, CART

ID3 1986年 Quilan

选择具有最高信息增益的属性作为测试属性

01 ID3(DataSet, featureList):
02   - 创建根节点R
03   - 如果当前DataSet中的数据都属于同一类,则标记R的类别为该类
04   - 如果当前featureList集合为空,则标记R的类别为当前DataSet中样本最多的类别
05   - 递归情况:
06     # 从featureList中选择属性(选择Gain(DataSet,F)最大的属性)
07     # 根据F的每一个值v,将DataSet划分为不同的子集DS,对每一个Ds:
08       - 创建节点C
09       - 如果DS为空,节点C标记为DataSet中样本最多的类别
10       - 如果DS不为空,节点C=ID3(DS,featureList-F)
11       - 将节点C添加为R的子节点


C4.5 1993年 by Quilab(对ID3的改进)

  • 信息增益率(information gain ratio)
  • 连续值属性

           离散化处理:将连续型的属性变量进行离散化处理,形成决策树的训练集

               - 把需要处理的样本按照连续变量的大小从小到大进行排序

               - 假设该属性对应的不同的属性值一共有N个,那么总共有N-1个可能的候选分割阈值点,每个候选的分割阈值点的值为上述排序后的属性值中两两前后元素的中点

               - 用信息增益率选择最佳划分

  • 缺失值
        - 处理缺少属性值的一种策略是赋给它节点t所对应的训练实例中该属性的最常见值


        - 复杂一点的办法是为每个可能值赋一个概率

        - 最简单的办法是丢弃这些样本

  • 后剪枝(基于错误剪枝EBP-Error Based Pruning)
    


C5.0 1998年

加入了Boosting算法框架

CART (Classification and Regression Trees)


  • 二元划分
        - 二叉树不易产生数据碎片,精确度往往会高于多叉树,所以在CART算法中采用二元划分  
  • 不纯性度量
        - 分类目标:Gini指标、Towing、order Towing


        - 连续目标:最小平方残差、最小绝对残差

  • 剪枝
        - 用独立的验证数据集对训练集生长的树进行剪枝


分类树:    

01 CART_classification(DataSet, featureList, alpha,):
02 创建根节点R
03 如果当前DataSet中的数据的类别相同,则标记R的类别标记为该类
04 如果决策树高度大于alpha,则不再分解,标记R的类别classify(DataSet)
05 递归情况:
06 标记R的类别classify(DataSet)
07 从featureList中选择属性F(选择Gini(DataSet, F)最小的属性划分,连续属性参考C4.5的离散化过程(以Gini最小作为划分标准))
08 根据F,将DataSet做二元划分DS_L 和 DS_R:
09 如果DS_L或DS_R为空,则不再分解
10 如果DS_L和DS_R都不为空,节点
11     C_L= CART_classification(DS_L, featureList, alpha);
12     C_R= CART_classification(DS_R featureList, alpha)
13 将节点C_L和C_R添加为R的左右子节点
回归树
01 CART_regression(DataSet, featureList, alpha, delta):
02 创建根节点R
03 如果当前DataSet中的数据的值都相同,则标记R的值为该值
04 如果最大的phi值小于设定阈值delta,则标记R的值为DataSet应变量均值
05 如果其中一个要产生的节点的样本数量小于alpha,则不再分解,标记R的值为DataSet应变量均值
06 递归情况:
07 从featureList中选择属性F(选择phi(DataSet, F)最大的属性,连续属性(或使用多个属性的线性组合)参考C4.5的离散化过程 (以phi最大作为划分标准))
08 根据F,将DataSet做二元划分DS_L 和 DS_R:
09 如果DS_L或DS_R为空,则标记节点R的值为DataSet应变量均值
10 如果DS_L和DS_R都不为空,节点
11     C_L= CART_regression(DS_L, featureList, alpha, delta);
12     C_R= CART_regression(DS_R featureList, alpha, delta)
13 将节点C_L和C_R添加为R的左右子节点
分类树与回归树的差别在于空间划分方法一个是线性一个是非线性

引自:http://m.blog.csdn.net/blog/android_asp/10218451


接下来就来学习OpenCV2.1.0 sample/c/mushroom.cpp里面的代码。

首先我们来分析主函数:

int main( int argc, char** argv )
{
    CvMat *data = 0, *missing = 0, *responses = 0;
    CvDTree* dtree;
    //读取样本数据
    const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";

    if( !mushroom_read_database( base_path, &data, &missing, &responses ) )
    {
        printf( "Unable to load the training database\n"
                "Pass it as a parameter: dtree <path to agaricus-lepiota.data>\n" );
        return 0;
        return -1;
    }
    //样本DT决策树构建
    dtree = mushroom_create_dtree( data, missing, responses,
        10 // poisonous mushrooms will have 10x higher weight in the decision tree
        );
    cvReleaseMat( &data );
    cvReleaseMat( &missing );
    cvReleaseMat( &responses );
    //蘑菇特征输入函数
    print_variable_importance( dtree, var_desc );
    //输入新样本数据决策函数
    interactive_classification( dtree, var_desc );
    delete dtree;

    return 0;
}
首先我们来分析第一个重要的函数,读取样本函数:

int mushroom_read_database( const char* filename, CvMat** data, CvMat** missing, CvMat** responses )
{
    const int M = 1024;
    FILE* f = fopen( filename, "rt" );
    CvMemStorage* storage;
    CvSeq* seq;
    char buf[M+2], *ptr;
    float* el_ptr;
    CvSeqReader reader;
    int i, j, var_count = 0;

    if( !f )
        return 0;

    // read the first line and determine the number of variables
    if( !fgets( buf, M, f ))
    {
        fclose(f);
        return 0;
    }

    for( ptr = buf; *ptr != '\0'; ptr++ )
        var_count += *ptr == ',';   //计算每个样本的数量,每个样本一个“,”,样本数量=var_count+1; 
    assert( ptr - buf == (var_count+1)*2 );

    // create temporary memory storage to store the whole database //把样本存入seq中,存储空间是storage; 
    el_ptr = new float[var_count+1];
    storage = cvCreateMemStorage();
    seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );

    for(;;)
    {
        for( i = 0; i <= var_count; i++ )
        {
            int c = buf[i*2];
            el_ptr[i] = c == '?' ? -1.f : (float)c;
        }
        if( i != var_count+1 )
            break;
        cvSeqPush( seq, el_ptr );
        if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
            break;
    }
    fclose(f);

    // allocate the output matrices and copy the base there
    *data = cvCreateMat( seq->total, var_count, CV_32F ); //样本数量,大小。
    *missing = cvCreateMat( seq->total, var_count, CV_8U );
    *responses = cvCreateMat( seq->total, 1, CV_32F );   //标志

    cvStartReadSeq( seq, &reader );

    for( i = 0; i < seq->total; i++ )
    {
        const float* sdata = (float*)reader.ptr + 1;
        float* ddata = data[0]->data.fl + var_count*i;
        float* dr = responses[0]->data.fl + i;
        uchar* dm = missing[0]->data.ptr + var_count*i;

        for( j = 0; j < var_count; j++ )
        {
            ddata[j] = sdata[j];
            dm[j] = sdata[j] < 0;
        }
        *dr = sdata[-1];
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    cvReleaseMemStorage( &storage );
    delete el_ptr;
    return 1;
}
在上面的函数里,我们可能最关键的需要样本的格式范例:

const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";
从这行代码里,我们找到该文件,我做了些翻译:

蘑菇的特征向量分量的特征:
[
毒性
帽的形状
顶曲面
帽子的颜色
瘀伤,伤痕,擦伤
气味
菌褶 附件
菌褶 间隔
大小
颜色
形状
根
茎表面上的环
茎表面以下环
秆色以上环
下面的茎色环
面罩式
面纱的颜色
环数
环式
孢子印的颜色
特定[生物]种群
栖息地,住处; 产地
]
p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g
e,b,s,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,n,m
p,x,y,w,t,p,f,c,n,n,e,e,s,s,w,w,p,w,o,p,k,s,u
e,x,s,g,f,n,f,w,b,k,t,e,s,s,w,w,p,w,o,e,n,a,g
e,x,y,y,t,a,f,c,b,n,e,c,s,s,w,w,p,w,o,p,k,n,g
e,b,s,w,t,a,f,c,b,g,e,c,s,s,w,w,p,w,o,p,k,n,m
e,b,y,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,s,m
p,x,y,w,t,p,f,c,n,p,e,e,s,s,w,w,p,w,o,p,k,v,g
e,b,s,y,t,a,f,c,b,g,e,c,s,s,w,w,p,w,o,p,k,s,m
e,x,y,y,t,l,f,c,b,g,e,c,s,s,w,w,p,w,o,p,n,n,g
e,x,y,y,t,a,f,c,b,n,e,c,s,s,w,w,p,w,o,p,k,s,m
e,b,s,y,t,a,f,c,b,w,e,c,s,s,w,w,p,w,o,p,n,s,g
这样样本的格式就一目了然,然后再去看读样本函数就容易些了。

接下来就是第二个最关键的函数:

CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,
                                const CvMat* responses, float p_weight )
{
    CvDTree* dtree;
    CvMat* var_type;
    int i, hr1 = 0, hr2 = 0, p_total = 0;
    float priors[] = { 1, p_weight };

    var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
    cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical

    dtree = new CvDTree;
    
    dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
                  CvDTreeParams( 8, // max depth
                                 10, // min sample count
                                 0, // regression accuracy: N/A here
                                 true, // 为真时,计算missing data和可变的重要性正确度  
                                 15, // max number of categories (use sub-optimal algorithm for larger numbers)
                                 10, // the number of cross-validation folds
                                 true, // use 1SE rule => smaller tree  修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确  
                                 true, // throw away the pruned tree branches
                                 priors // the array of priors, the bigger p_weight, the more attention
                                        // to the poisonous mushrooms
                                        // (a mushroom will be judjed to be poisonous with bigger chance)
                                 ));

    // compute hit-rate on the training database, demonstrates predict usage.
    /* 使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;   */
    for( i = 0; i < data->rows; i++ )
    {
        CvMat sample, mask;
        cvGetRow( data, &sample, i );
        cvGetRow( missing, &mask, i );
        double r = dtree->predict( &sample, &mask )->value;
        int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;//大于阈值FLT_EPSILON被判断为误检  
        if( d )
        {
            if( r != 'p' )
                hr1++;
            else
                hr2++;
        }
        p_total += responses->data.fl[i] == 'p';
    }

    printf( "Results on the training database:\n"
            "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"
            "\tFalse-alarms: %d (%g%%)\n", hr1, (double)hr1*100/p_total,
            hr2, (double)hr2*100/(data->rows - p_total) );

    cvReleaseMat( &var_type );

    return dtree;
}
最后一个重要的函数就是新样本的识别函数:

void interactive_classification( CvDTree* dtree, const char** var_desc )
{
    char input[1000];
    const CvDTreeNode* root;
    CvDTreeTrainData* data;

    if( !dtree )
        return;

    root = dtree->get_root();
    data = dtree->get_data();

    for(;;)
    {
        const CvDTreeNode* node;
        
        printf( "Start/Proceed with interactive mushroom classification (y/n): " );
        scanf( "%1s", input );
        if( input[0] != 'y' && input[0] != 'Y' ) //开始输入蘑菇的特征参数
            break;
        printf( "Enter 1-letter answers, '?' for missing/unknown value...\n" ); 

        // custom version of predict
        node = root;
        for(;;)
        {
            CvDTreeSplit* split = node->split;
            int dir = 0;
            
            if( !node->left || node->Tn <= dtree->get_pruned_tree_idx() || !node->split )
                break;

            for( ; split != 0; )
            {
                int vi = split->var_idx, j;
                int count = data->cat_count->data.i[vi];
                const int* map = data->cat_map->data.i + data->cat_ofs->data.i[vi];

                printf( "%s: ", var_desc[vi] );
                scanf( "%1s", input );

                if( input[0] == '?' )
                {
                    split = split->next;
                    continue;
                }

                // convert the input character to the normalized value of the variable
                for( j = 0; j < count; j++ )
                    if( map[j] == input[0] )
                        break;
                if( j < count )
                {
                    dir = (split->subset[j>>5] & (1 << (j&31))) ? -1 : 1;
                    if( split->inversed )
                        dir = -dir;
                    break;
                }
                else
                    printf( "Error: unrecognized value\n" );
            }
            
            if( !dir )
            {
                printf( "Impossible to classify the sample\n");
                node = 0;
                break;
            }
            node = dir < 0 ? node->left : node->right;
        }

        if( node )
            printf( "Prediction result: the mushroom is %s\n",
                    node->class_idx == 0 ? "EDIBLE" : "POISONOUS" );//判断输出有毒或可食用两类
        printf( "\n-----------------------------\n" );
    }
}






评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值