课后题5.2
import numpy as np
# 原始数据
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y = np.array([4.50, 4.75, 4.91, 5.34, 5.80, 7.05, 7.90, 8.23, 8.70, 9.00])
# 建立简单的树类
class DecisionTree():
def __init__(self, val, name, left_val, right_val):
self.value = val
self.name = name
self.left = left_val
self.right = right_val
# 储存节点对象的列表
nodes = []
# 根据平方误差原则生成子节点的递归函数
def cutPoint(x, y,):
# 如果只由一个元素则不必再生成子节点,返回叶子节点完成递归,但不再将其加入到 nodes 列表
if len(x)==1:
return DecisionTree(y, y.mean(), 0, 0)
else:
# 存储挑选分割点(splitting point)时,不同分割点产生的平方误差大小
errs = []
# 将原来的样本,按照特征数值(x值)升序排列,方便分割时,分割点左边的样本特征值都小于右边
y = y[np.argsort(x)]
x.sort()
# 计算在不同分割点处平方误差大小,并存入 errs 列表
for i in range(len(x)):
err1, err2 = 0, 0
y1 = y[:i]
y2 = y[i:]
if y1.size>0:
err1 = sum((y1 - y1.mean())**2)
elif y2.size>0:
err2 = sum((y2 - y2.mean())**2)
else:
raise
err = err1 + err2
errs.append(err)
# 挑选平方误差最小的分割点
idx = errs.index(min(errs))
spliting_point = x[idx]
spliting_condition = y[idx] # 分割点对应的样本的y值
# 建立父子节点对象(DecisionTree Object)
val = (x, y)
name = y.mean()
left_val = (x[:idx], y[:idx])
right_val = (x[idx:], y[idx:])
nodes.append(DecisionTree(val, name, left_val, right_val))
# 将分割后的两类,即子节点当作父节点递归调用 cutPoint
cutPoint(x[:idx], y[:idx])
cutPoint(x[idx:], y[idx:])
cutPoint(x,y)
for idx, itm in enumerate(nodes):
print(f'node{idx+1}:', itm.value, '\nmean:',itm.name, '\nleft:',itm.left, '\nright:', itm.right, '\n')
node1: (array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), array([4.5 , 4.75, 4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
mean: 6.618
left: (array([1]), array([4.5]))
right: (array([ 2, 3, 4, 5, 6, 7, 8, 9, 10]), array([4.75, 4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
node2: (array([ 2, 3, 4, 5, 6, 7, 8, 9, 10]), array([4.75, 4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
mean: 6.8533333333333335
left: (array([2]), array([4.75]))
right: (array([ 3, 4, 5, 6, 7, 8, 9, 10]), array([4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
node3: (array([ 3, 4, 5, 6, 7, 8, 9, 10]), array([4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
mean: 7.11625
left: (array([3]), array([4.91]))
right: (array([ 4, 5, 6, 7, 8, 9, 10]), array([5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
node4: (array([ 4, 5, 6, 7, 8, 9, 10]), array([5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
mean: 7.431428571428573
left: (array([4]), array([5.34]))
right: (array([ 5, 6, 7, 8, 9, 10]), array([5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
node5: (array([ 5, 6, 7, 8, 9, 10]), array([5.8 , 7.05, 7.9 , 8.23, 8.7 , 9. ]))
mean: 7.78
left: (array([5]), array([5.8]))
right: (array([ 6, 7, 8, 9, 10]), array([7.05, 7.9 , 8.23, 8.7 , 9. ]))
node6: (array([ 6, 7, 8, 9, 10]), array([7.05, 7.9 , 8.23, 8.7 , 9. ]))
mean: 8.175999999999998
left: (array([6]), array([7.05]))
right: (array([ 7, 8, 9, 10]), array([7.9 , 8.23, 8.7 , 9. ]))
node7: (array([ 7, 8, 9, 10]), array([7.9 , 8.23, 8.7 , 9. ]))
mean: 8.4575
left: (array([7]), array([7.9]))
right: (array([ 8, 9, 10]), array([8.23, 8.7 , 9. ]))
node8: (array([ 8, 9, 10]), array([8.23, 8.7 , 9. ]))
mean: 8.643333333333333
left: (array([8]), array([8.23]))
right: (array([ 9, 10]), array([8.7, 9. ]))
node9: (array([ 9, 10]), array([8.7, 9. ]))
mean: 8.85
left: (array([9]), array([8.7]))
right: (array([10]), array([9.]))
上面一共产生了9个节点,为决策树的中间节点(internal node),其中根节点按照书上定义,也属于中间节点,所有在“left”,”right“当中,只由单一样本的为叶子节点。对照前人用python2写出的没有递归的程序,最终答案是一致的,下图为前辈的答案: