已知如下表所示的训练数据,x 的取值范围为区间[0.5,10.5],y 的取值范围为区间[5.0,10.0],
学习这个回归问题的提升树模型,考虑只用树桩作为基函数。
xi | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
yi | 5.56 | 5.70 | 5.91 | 6.40 | 6.80 | 7.05 | 8.90 | 8.70 | 9.00 | 9.05 |
终止条件:模型的平方差损失函数误差小于 0.2
注:树桩是由一个根节点直接连接两个叶结点的简单决策树
import numpy as np
class Tree:
def __init__(self, node, left_value, right_value, residual):
# 节点,左值,右值,提升树的残差
self.node = node
self.left_value = left_value
self.right_value = right_value
self.residual = residual
x = np.array([i for i in range(1,11)])
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05])
# y = y.astype(np.float)
# print(y.dtype)
# 每个提升树的生成
def tree_p(y):
mini = np.inf
node, c1, c2 = 0, 0, 0
for i in np.arange(x[0] + 0.5, x[-2] + 0.5, 1):
y_l = y[x < i]
y_ls = sum((y_l-y_l.mean())**2)
y_r = y[x > i]
y_lr = sum((y_r-y_r.mean())**2)
y_ = y_ls + y_lr
# 注意转换类型
# y_ = np.float(y_)
# print(type(mini),type(y_))
# print(y_)
if y_ < np.float(mini):
node = i
mini = '{:.2f}'.format(y_)
# 二分类左边及右边的预测取值
c1 = '{:.2f}'.format(y_l.mean())
c2 = '{:.2f}'.format(y_r.mean())
# print(node, mini, c1, c2)
# 找残差
yy = y.copy()
yy[x<node] = yy[x<node] - np.float(c1)
yy[x>node] = yy[x>node] - np.float(c2)
# print(sum(y**2))
tree = Tree(node, c1, c2, yy)
# print(y,yy)
return tree
# 模型的平方差损失函数误差
def get_error(y, predict):
return np.sum((y-predict)**2)
# 得到每个提升树模型的预测值
def predict(tree, predict_):
predict_[x<tree.node] += np.float(tree.left_value)
predict_[x>tree.node] += np.float(tree.right_value)
return predict_
# 生成提升树直到模型的平方差损失函数误差小于 0.2
def bdt_model():
yyy = y.copy()
predict_ = np.zeros(np.shape(x)[0])
for i in range(100):
# 生成树
tree = tree_p(yyy)
# 树的残差
yyy = tree.residual
# 更新预测值
predict_ = predict(tree, predict_)
# print(yyy, predict_)
if get_error(y, predict_) < 0.2:
# print(tree.node)
print('共运行 '+ str(i + 1) + ' 次,' + '最终平方损失误差是 {}'.format('%.2f' % get_error(y, predict_)))
print('最终分类预测结果为')
return predict_
else:
print(predict_)
print(bdt_model())
运行结果
[6.24 6.24 6.24 6.24 6.24 6.24 8.91 8.91 8.91 8.91]
[5.72 5.72 5.72 6.46 6.46 6.46 9.13 9.13 9.13 9.13]
[5.87 5.87 5.87 6.61 6.61 6.61 8.91 8.91 8.91 8.91]
[5.71 5.71 5.71 6.45 6.72 6.72 9.02 9.02 9.02 9.02]
[5.78 5.78 5.78 6.52 6.79 6.79 8.91 8.91 8.91 8.91]
共运行 6 次,最终平方损失误差是 0.17
最终分类预测结果为
[5.56 5.8 5.8 6.54 6.81 6.81 8.93 8.93 8.93 8.93]