机器学习基础(十八) —— decision stump

基本原理

decision stump,决策树桩(我称它为一刀切),也称单层决策树(a one level decision tree),单层也就意味着尽可对每一列属性进行一次判断。如下图所示(仅对 petal length 进行了判断):


这里写图片描述

从树(数据结构)的观点来看,它由一个内部节点(internal node)也即根节点(root)与终端节点(terminal node)也即叶子节点(leaves)直接相连。用作分类器(classifier)的 decision stump 的叶子节点也就意味着最终的分类结果。

从实际意义来看,decision stump 根据一个属性的一个判断就决定了最终的分类结果,比如根据水果是否是圆形判断水果是否为苹果,这体现的是单一简单的规则(或叫特征)在起作用。

显然 decision stump 仅可作为一个 weak base learning algorithm(它会比瞎猜 12 稍好一点点,但好的程度十分有限),常用作集成学习中的 base algorithm,而不会单独作为分类器。

既然 decision stump 仅可对一个属性进行一次判断获取最终的分类结果,显然我们寻找具有最低错误率的单层决策树。

所要优化的目标函数为:

argmin1id1Nn=1N1yngi(x)

i 表示属性列,N 为样本集的大小, d <script type="math/tex" id="MathJax-Element-215">d</script> 为属性列的个数。

代码实现

找到具有最低错误率的单层决策树,需要遍历全部的属性列,遍历属性列下所有可能的阈值(当然在一定的步长范围内),以及所有的 True/False 的分配,也即至少需要三层循环。

# 该函数不直接在客户端调用,而被 buildStump 调用
# 如果使用面向对象的思路封装的话,
# 该函数会作为私用成员函数被其他公有成员函数调用
# 如上图所示,这里传递进来的第三个参数是一个具体的数值,
# 更简便,且能力更强的做法是传递进来一个断言 predicate

def stumpClassify(X, j, thresh, ineq):
                                    # X:数据集,j:属性列
                                    # thresh:阈值,ineq:比较
    pred = np.ones(X.shape[0])
    if ineq == 'lf':
        pred[X[:, j] <= thresh] = -1
    else:
        pred[X[:, j] > thresh] = -1
    return pred

def buildStump(X, y, w):
                            # X:数据集,y:labels
                            # w:权值,初始化为 np.ones(N)/N
    N, d = X.shape
    minErr = np.inf
    numStep = 10
    bestStump, bestLabel = {}, np.zeros(N), 

    for j in range(N):
        rangeMin, rangeMax = X[:, j].min(), X[:, j].max()
        stepSize = (rangeMax-rangeMin)/numStep
        for i in range(-1, numStep+1):
            thresh = rangeMin + i*stepSize
            for ineq in ['lt', 'gt']:
                pred = stumpClassify(X, j, thresh, ineq)
                errLabel = np.ones(N)
                errLabel[pred == y] = 1
                weightedErr = w.dot(errLabel)
                if minErr > weightErr:
                    minErr = weightErr
                    bestClass = pred
                    bestStump['dim'] = j
                    bestStump['ineq'] = ineq
                    bestStump['thresh'] = thresh
    return bestStump, minErr, bestClass                 
相关推荐
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页