统计学习方法——提升方法(三)

提升方法(三)

现在我们来看一个AdaBoost算法的简单实现。

数据

这里我们采用几个简单数据作为训练数据:

######训练样本######
def load_simple_data():
    data_mat = np.matrix([[1.0,2.1],[2.0,1.1],[1.3,1.0],[1.0,1.0],[2.0,1.0]])
    class_labels=[1.0,1.0,-1.0,-1.0,1.0]
    return data_mat,class_labels

实现

######根据阈值判断每个特征的输出值######
#ret_array初始化时不能全1或-1
def stump_classify(data_matrix,dimen,threshval,thresh_ineq):
    ret_array=zeros((shape(data_matrix)[0],1))
    if thresh_ineq=="lt":
        ret_array[data_matrix[:,dimen]<=threshval]=-1.0
    else:
        ret_array[data_matrix[:,dimen]>threshval]=1.0
    return ret_array   

 ######实现单层决策树######
#lt_predicted_arr存储小于等于阈值的输出值,gt_predicted_arr存储大于阈值的输出值
def build_stump(data_arr,class_labels,D):
    data_matrix=mat(data_arr)   #数据集
    label_mat=mat(class_labels).T  #标签集
    m,n=shape(data_matrix)   #数据集的行、列
    num_steps=10.0      #用于设置划分点
    best_stump={}
    best_class_est=mat(zeros((m,1)))
    min_error=inf
    
    for i in range(n):
        range_min=data_matrix[:,i].min()   #数据集中某一列的最小值
        range_max=data_matrix[:,i].max()   #数据集中某一列的最大值
        step_size=(range_max-range_min)*1.0/num_steps   #用于设置划分点
        for j in range(-1,int(num_steps)+1):
            thresh_val=(range_min+float(j)*step_size)   #计算其中一个划分点
            lt_predicted_arr=zeros((m,1))     #获得析域不等号的值
            gt_predicted_arr=zeros((m,1))     #获得大于不等号的值
            predicted_arr=zeros((m,1))        #获得最终的预测值
            for inequal in ["lt","gt"]:
                preducted_vals=stump_classify(data_matrix,i,thresh_val,inequal)
                if inequal=='lt':
                    lt_predicted_arr=preducted_vals
                else:
                    gt_predicted_arr=preducted_vals
            for k in range(m):
                predicted_arr[k] = lt_predicted_arr[k]
                if(gt_predicted_arr[k] != 0):
                    predicted_arr[k] = gt_predicted_arr[k]
            err_arr = mat(ones((m, 1)))
            err_arr[predicted_arr == label_mat] = 0
            weight_error = D.T * err_arr
            print("min_error = %0.5f, split: dim %d, thresh %0.2f,the weighted error is %0.3f" %(min_error, i, thresh_val, weight_error))
            if(weight_error < min_error):
                min_error = weight_error
                best_class_est = predicted_arr.copy()
                best_stump["dim"] = i
                best_stump["thresh"] = thresh_val
                best_stump["class_est"] = best_class_est
    return best_stump,min_error,best_class_est

###根据单层决策树得到一个弱分类器保存在列表中
###根据得到的弱分类器输出的预测值和原始值修改权重,减小误差小的点的权重,增大误差大的点的权重。最后误差率等于0则停止迭代。
def adaboost_train_DS(data_arr, class_labels, num_it = 40):
    weak_class_arr = []
    m = shape(data_arr)[0]
    D = mat(ones((m, 1)) / m)
    agg_class_est = mat(zeros((m, 1)))
    counts = 0

    for i in range(num_it):
        best_stump,error,class_est = build_stump(data_arr, class_labels, D)
        print("D: ", D.T)
        alpha = float(0.5 * log((1.0 - error) / error))
        best_stump["alpha"] = alpha
        weak_class_arr.append(best_stump)
        print("class_est: ", class_est.T)
        expon = multiply(-1 * alpha * mat(class_labels).T, class_est)
        D = multiply(D, exp(expon))
        D = D / D.sum()
        agg_class_est += alpha * class_est  #预测值
        print("agg_class_est: ", agg_class_est.T)
        agg_errors = multiply(sign(agg_class_est) != mat(class_labels).T, ones((m, 1)))
        error_rate = agg_errors.sum() / m
        print("total error: ", error_rate)
        counts += 1
        if(error_rate == 0.0):
            break
    return weak_class_arr,counts

###弱分类器:根据每一个弱分类器阈值,输出每个特征的输出值。
def Gx(data_matrix, reshval):
    m = shape(data_matrix)[0]
    result_data = mat(ones((m, 1)))
    result_data[data_matrix[:, 0] <= reshval] = -1.0
    result_data[data_matrix[:, 0] > reshval] = 1.0
    return result_data

###根据输出数据,输出实际预测值。
def ada_classfiy(input_data, weak_class_arr, counts):
    data_matrix = mat(input_data)
    m,n = shape(data_matrix)  
    agg_class_est = mat(zeros((m, 1)))
    for j in range(n):
        for i in range(len(weak_class_arr)):
            agg_class_est += (weak_class_arr[i]["alpha"] * (Gx(input_data[:, j], weak_class_arr[i]["thresh"])).T).T
    print(agg_class_est)
    return sign(agg_class_est)

###画图###
def experiment_plot(data_matrix, agg_class_est):
    data_arr_in = data_matrix.getA()
    label_arr_in = agg_class_est.getA()
    m,n = shape(data_matrix)

    for i in range(m):
        if(label_arr_in[i, 0] == -1):
            plt.plot(data_arr_in[i, 0], data_arr_in[i, 1], "ob")
        elif(label_arr_in[i, 0] == 1):
            plt.plot(data_arr_in[i, 0], data_arr_in[i, 1], "or")

    plt.xlabel("X")
    plt.ylabel("Y")
    plt.show()

测试与结果

def main():
    #D = mat(ones((5, 1)) / 5.0)
    data_mat,class_labels = load_simple_data()
    #build_stump(data_mat, class_labels, D)
    weak_class_arr,counts = adaboost_train_DS(data_mat, class_labels)
    print("**************************")
    data_mat2 = matrix([[1.2, 2.0],
                       [1.0, 1.0],
                       [0.6, 1.2],
                       [0.8, 1.1],
                       [1.8, 1.0],
                       [2.0, 0.4],
                       [1.7, 0.8],
                       [3.5, 0.2],
                       [2.5, 0.8],
                       [2.8, 0.9]])
    agg_class_est = ada_classfiy(data_mat2, weak_class_arr, counts)
    print("--------------------------")
    print(agg_class_est)
    experiment_plot(data_mat2, agg_class_est)

main()

结果如下所示:


min_error = inf, split: dim 0, thresh 0.90,the weighted error is 0.400
min_error = 0.40000, split: dim 0, thresh 1.00,the weighted error is 0.400
min_error = 0.40000, split: dim 0, thresh 1.10,the weighted error is 0.400
min_error = 0.40000, split: dim 0, thresh 1.20,the weighted error is 0.400
min_error = 0.40000, split: dim 0, thresh 1.30,the weighted error is 0.200
min_error = 0.20000, split: dim 0, thresh 1.40,the weighted error is 0.200
min_error = 0.20000, split: dim 0, thresh 1.50,the weighted error is 0.200
min_error = 0.20000, split: dim 0, thresh 1.60,the weighted error is 0.200
min_error = 0.20000, split: dim 0, thresh 1.70,the weighted error is 0.200
min_error = 0.20000, split: dim 0, thresh 1.80,the weighted error is 0.200
min_error = 0.20000, split: dim 0, thresh 1.90,the weighted error is 0.200
min_error = 0.20000, split: dim 0, thresh 2.00,the weighted error is 0.600
min_error = 0.20000, split: dim 1, thresh 0.89,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.00,the weighted error is 0.200
min_error = 0.20000, split: dim 1, thresh 1.11,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.22,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.33,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.44,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.55,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.66,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.77,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.88,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 1.99,the weighted error is 0.400
min_error = 0.20000, split: dim 1, thresh 2.10,the weighted error is 0.600
D: [[0.2 0.2 0.2 0.2 0.2]]
class_est: [[-1. 1. -1. -1. 1.]]
agg_class_est: [[-0.69314718 0.69314718 -0.69314718 -0.69314718 0.69314718]]
total error: 0.2
min_error = inf, split: dim 0, thresh 0.90,the weighted error is 0.250
min_error = 0.25000, split: dim 0, thresh 1.00,the weighted error is 0.625
min_error = 0.25000, split: dim 0, thresh 1.10,the weighted error is 0.625
min_error = 0.25000, split: dim 0, thresh 1.20,the weighted error is 0.625
min_error = 0.25000, split: dim 0, thresh 1.30,the weighted error is 0.500
min_error = 0.25000, split: dim 0, thresh 1.40,the weighted error is 0.500
min_error = 0.25000, split: dim 0, thresh 1.50,the weighted error is 0.500
min_error = 0.25000, split: dim 0, thresh 1.60,the weighted error is 0.500
min_error = 0.25000, split: dim 0, thresh 1.70,the weighted error is 0.500
min_error = 0.25000, split: dim 0, thresh 1.80,the weighted error is 0.500
min_error = 0.25000, split: dim 0, thresh 1.90,the weighted error is 0.500
min_error = 0.25000, split: dim 0, thresh 2.00,the weighted error is 0.750
min_error = 0.25000, split: dim 1, thresh 0.89,the weighted error is 0.250
min_error = 0.25000, split: dim 1, thresh 1.00,the weighted error is 0.125
min_error = 0.12500, split: dim 1, thresh 1.11,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 1.22,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 1.33,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 1.44,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 1.55,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 1.66,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 1.77,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 1.88,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 1.99,the weighted error is 0.250
min_error = 0.12500, split: dim 1, thresh 2.10,the weighted error is 0.750
D: [[0.5 0.125 0.125 0.125 0.125]]
class_est: [[ 1. 1. -1. -1. -1.]]
agg_class_est: [[ 0.27980789 1.66610226 -1.66610226 -1.66610226 -0.27980789]]
total error: 0.2
min_error = inf, split: dim 0, thresh 0.90,the weighted error is 0.143
min_error = 0.14286, split: dim 0, thresh 1.00,the weighted error is 0.357
min_error = 0.14286, split: dim 0, thresh 1.10,the weighted error is 0.357
min_error = 0.14286, split: dim 0, thresh 1.20,the weighted error is 0.357
min_error = 0.14286, split: dim 0, thresh 1.30,the weighted error is 0.286
min_error = 0.14286, split: dim 0, thresh 1.40,the weighted error is 0.286
min_error = 0.14286, split: dim 0, thresh 1.50,the weighted error is 0.286
min_error = 0.14286, split: dim 0, thresh 1.60,the weighted error is 0.286
min_error = 0.14286, split: dim 0, thresh 1.70,the weighted error is 0.286
min_error = 0.14286, split: dim 0, thresh 1.80,the weighted error is 0.286
min_error = 0.14286, split: dim 0, thresh 1.90,the weighted error is 0.286
min_error = 0.14286, split: dim 0, thresh 2.00,the weighted error is 0.857
min_error = 0.14286, split: dim 1, thresh 0.89,the weighted error is 0.143
min_error = 0.14286, split: dim 1, thresh 1.00,the weighted error is 0.500
min_error = 0.14286, split: dim 1, thresh 1.11,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 1.22,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 1.33,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 1.44,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 1.55,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 1.66,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 1.77,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 1.88,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 1.99,the weighted error is 0.571
min_error = 0.14286, split: dim 1, thresh 2.10,the weighted error is 0.857
D: [[0.28571429 0.07142857 0.07142857 0.07142857 0.5 ]]
class_est: [[1. 1. 1. 1. 1.]]
agg_class_est: [[ 1.17568763 2.56198199 -0.77022252 -0.77022252 0.61607184]]
total error: 0.0

[[ 3.73766962e+00]
[-1.54044504e+00]
[-1.38629436e+00]
[-1.38629436e+00]
[ 1.79175947e+00]
[-2.22044605e-16]
[-2.22044605e-16]
[-2.22044605e-16]
[-2.22044605e-16]
[-2.22044605e-16]]

[[ 1.]
[-1.]
[-1.]
[-1.]
[ 1.]
[-1.]
[-1.]
[-1.]
[-1.]
[-1.]]

变化


完整代码

完整代码:https://github.com/canshang/-/blob/master/AdaBoost.ipynb

参考文献

《统计学习方法》
《机器学习实战》
https://www.jianshu.com/p/2ea2f4ec121c?t=123&utm_source=oschina-app

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值