AdaBoost - 李航 8.1.3 - 代码实现
《统计学习方法》（第2版）p158

x:   0   1   2   3   4   5   6   7   8   9
y:   1   1   1  -1  -1  -1   1   1   1  -1
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  5 14:30:03 2019

@author: puffy
"""

import numpy as np

# Training sample from the table above (Li Hang, 8.1.3, p.158):
# one 1-D feature x = 0..9 and labels y in {-1, +1}.
x = np.arange(10)
y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])

def stump_classifer(x, thre, direction):
    """Decision stump: label each sample +1 or -1 by comparing it to a threshold.

    Parameters
    ----------
    x : array_like
        Feature values. Lists and other array-likes are accepted and
        converted with ``np.asarray``.
    thre : float
        Split threshold.
    direction : int
        ``1``: predict +1 where ``x < thre``, otherwise -1.
        Any other value: predict +1 where ``x > thre``, otherwise -1.
        Samples with ``x == thre`` get -1 in both cases.

    Returns
    -------
    np.ndarray of int
        Predictions in {-1, +1}, same shape as ``x``.
    """
    x = np.asarray(x)
    positive = (x < thre) if direction == 1 else (x > thre)
    return np.where(positive, 1, -1)

def error_empirical(y, y_pre, w_m):
    """Weighted classification error of a weak classifier.

    Computes e_m = sum_i w_mi * I(y_pre_i != y_i), i.e. the total sample
    weight sitting on the misclassified points.
    """
    misclassified = (y != y_pre)
    return np.sum(w_m * misclassified)

def grid_search_thre(x, y, w, num):
    """Find the best decision stump by scanning evenly spaced thresholds.

    ``num`` candidate thresholds are taken uniformly from
    ``[min(x), max(x)]``. For each candidate, both stump directions
    (``1`` and ``2``, see ``stump_classifer``) are scored with the weighted
    error ``error_empirical``. The better direction is kept, and on a tie
    direction 1 wins.

    Returns
    -------
    (flag, thre)
        Direction and threshold of the stump with the lowest weighted
        error. If several candidates tie, the first one is returned.
    """
    candidates = np.linspace(min(x), max(x), num)
    errors, directions = [], []
    for threshold in candidates:
        err_below = error_empirical(y, stump_classifer(x, threshold, 1), w)
        err_above = error_empirical(y, stump_classifer(x, threshold, 2), w)
        if err_below <= err_above:
            chosen_dir, chosen_err = 1, err_below
        else:
            chosen_dir, chosen_err = 2, err_above
        directions.append(chosen_dir)
        errors.append(chosen_err)

    winner = np.argmin(errors)
    return directions[winner], candidates[winner]

def alpha_m(em, eps=1e-10):
    """Coefficient of a weak classifier: alpha_m = 0.5 * ln((1 - e_m) / e_m).

    Parameters
    ----------
    em : float
        Weighted error rate of the weak classifier.
    eps : float, optional
        ``em`` is clipped into ``[eps, 1 - eps]`` before the formula is
        applied. Without the clip, a perfect stump (``em == 0``) yields
        ``alpha = inf`` (or ``ZeroDivisionError`` for a Python float). The
        exponential weight update then divides by a zero normaliser and
        every weight becomes NaN. ``em == 1`` likewise yields ``-inf``.

    Returns
    -------
    float
        The classifier's vote weight. It is positive when ``em < 0.5``,
        zero at 0.5 and negative above.
    """
    em = np.clip(em, eps, 1 - eps)
    return 0.5 * np.log((1 - em) / em)

def update_w(w, alpha, y, y_pre):
    """Re-weight the training samples after a boosting round.

    w_{m+1,i} = w_mi * exp(-alpha * y_i * G_m(x_i)) / Z_m, where Z_m is the
    sum of the un-normalised weights. Misclassified samples
    (``y * y_pre == -1``) are scaled up and correctly classified ones are
    scaled down. The result sums to 1.
    """
    scaled = w * np.exp(-alpha * y * y_pre)
    normaliser = np.sum(scaled)
    return scaled / normaliser

def adaboost_train(x, y, m=3, num=30):
    """Train AdaBoost with decision stumps on 1-D data.

    Each round:
      1. picks the best stump under the current sample weights
         (``grid_search_thre``);
      2. computes the stump's weighted error (``error_empirical``) and its
         coefficient alpha (``alpha_m``);
      3. re-weights the samples (``update_w``).

    Parameters
    ----------
    x : array_like
        1-D feature values.
    y : array_like
        Labels in {-1, +1}, same length as ``x``.
    m : int, optional
        Number of boosting rounds (weak classifiers). Defaults to 3.
    num : int, optional
        Number of candidate thresholds scanned per round.

    Returns
    -------
    dict
        ``{round_index: [flag, thre, alpha]}``, the format
        ``ababoost_predict`` consumes.

    Raises
    ------
    ValueError
        If ``x`` is empty or ``x`` and ``y`` differ in length.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = x.shape[0]
    if n == 0:
        raise ValueError('x must contain at least one sample')
    if y.shape[0] != n:
        raise ValueError('x and y must have the same length')

    model_info = {}
    # Initial sample weights are uniform: w_1i = 1 / N.
    w_m = np.full(n, 1.0 / n)
    for i in range(m):
        flag, thre = grid_search_thre(x, y, w_m, num)
        y_pre = stump_classifer(x, thre, flag)
        e = error_empirical(y, y_pre, w_m)
        alpha = alpha_m(e)
        model_info[i] = [flag, thre, alpha]
        w_m = update_w(w_m, alpha, y, y_pre)
    return model_info


# Model trained on the module-level sample data; the __main__ section uses it.
model_info = adaboost_train(x, y, m=3)

def ababoost_predict(x, model_info):
    """Predict labels with a trained AdaBoost ensemble of decision stumps.

    Computes G(x) = sign(sum_m alpha_m * G_m(x)).

    Parameters
    ----------
    x : array_like or scalar
        Feature values to classify. A scalar is treated as a single sample.
    model_info : dict
        ``{round_index: [flag, thre, alpha]}`` as produced by the training
        loop. Only the values are used, so the keys need not be exactly
        ``0..m-1``.

    Returns
    -------
    np.ndarray of float
        ``np.sign`` of the weighted vote for each sample. It is -1.0 or
        1.0, or 0.0 if the weighted votes cancel exactly.
    """
    x = np.atleast_1d(np.asarray(x))
    score = np.zeros(x.shape[0])
    for info in model_info.values():
        flag, thre, alpha = info[0], info[1], info[2]
        score += alpha * stump_classifer(x, thre, flag)
    return np.sign(score)

if __name__ == '__main__':
    # Goal: keep, for every boosting round m, the stump's threshold, direction
    # flag and weight, so that any input x can be run through the model to
    # get a predicted y.
    y_pre = ababoost_predict(x, model_info)
    print('model_info (round -> [flag, thre, alpha]):', model_info)
    print('predicted y:', y_pre)
    print('training accuracy:', np.mean(y_pre == y))



