《统计学习方法(第2版)》李航第五章决策树课后习题5.2答案(使用python3编写,递归算法)

课后题5.2
转自:https://zhuanlan.zhihu.com/p/166393579

import numpy as np

# 原始数据
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y = np.array([4.50, 4.75, 4.91, 5.34, 5.80, 7.05, 7.90, 8.23, 8.70, 9.00])


# 建立简单的树类
class DecisionTree():
    def __init__(self, val, name, left_val, right_val):
        self.value = val
        self.name = name
        self.left = left_val
        self.right = right_val
        
# 储存节点对象的列表        
nodes = []


# 根据平方误差原则生成子节点的递归函数
def cutPoint(x, y,):
    # 如果只由一个元素则不必再生成子节点,返回叶子节点完成递归,但不再将其加入到 nodes 列表
    if len(x)==1:
        return DecisionTree(y, y.mean(), 0, 0)
    else:
        # 存储挑选分割点(splitting point)时,不同分割点产生的平方误差大小
        errs = []
        # 将原来的样本,按照特征数值(x值)升序排列,方便分割时,分割点左边的样本特征值都小于右边
        y = y[np.argsort(x)]
        x.sort()
        # 计算在不同分割点处平方误差大小,并存入 errs 列表
        for i in range(len(x)):
            err1, err2 = 0, 0
            y1 = y[:i]
            y2 = y[i:]
            if y1.size>0:
                err1 = sum((y1 - y1.mean())**2)
            elif y2.size>0:
                err2 = sum((y2 - y2.mean())**2)
            else:
                raise 
            err = err1 + err2
            errs.append(err)
        # 挑选平方误差最小的分割点
        idx = errs.index(min(errs))
        spliting_point = x[idx]
        spliting_condition = y[idx]  # 分割点对应的样本的y值
        # 建立父子节点对象(DecisionTree Object)
        val = (x, y)
        name = y.mean()
        left_val = (x[:idx], y[:idx])
        right_val = (x[idx:], y[idx:])
        nodes.append(DecisionTree(val, name, left_val, right_val))
        # 将分割后的两类,即子节点当作父节点递归调用 cutPoint
        cutPoint(x[:idx], y[:idx])
        cutPoint(x[idx:], y[idx:])  
cutPoint(x,y)
for idx, itm in enumerate(nodes):
    print(f'node{idx+1}:', itm.value, '\nmean:',itm.name, '\nleft:',itm.left, '\nright:', itm.right, '\n')
node1: (array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]), array([4.5 , 4.75, 4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 
mean: 6.618 
left: (array([1]), array([4.5])) 
right: (array([ 2,  3,  4,  5,  6,  7,  8,  9, 10]), array([4.75, 4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 

node2: (array([ 2,  3,  4,  5,  6,  7,  8,  9, 10]), array([4.75, 4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 
mean: 6.8533333333333335 
left: (array([2]), array([4.75])) 
right: (array([ 3,  4,  5,  6,  7,  8,  9, 10]), array([4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 

node3: (array([ 3,  4,  5,  6,  7,  8,  9, 10]), array([4.91, 5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 
mean: 7.11625 
left: (array([3]), array([4.91])) 
right: (array([ 4,  5,  6,  7,  8,  9, 10]), array([5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 

node4: (array([ 4,  5,  6,  7,  8,  9, 10]), array([5.34, 5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 
mean: 7.431428571428573 
left: (array([4]), array([5.34])) 
right: (array([ 5,  6,  7,  8,  9, 10]), array([5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 

node5: (array([ 5,  6,  7,  8,  9, 10]), array([5.8 , 7.05, 7.9 , 8.23, 8.7 , 9.  ])) 
mean: 7.78 
left: (array([5]), array([5.8])) 
right: (array([ 6,  7,  8,  9, 10]), array([7.05, 7.9 , 8.23, 8.7 , 9.  ])) 

node6: (array([ 6,  7,  8,  9, 10]), array([7.05, 7.9 , 8.23, 8.7 , 9.  ])) 
mean: 8.175999999999998 
left: (array([6]), array([7.05])) 
right: (array([ 7,  8,  9, 10]), array([7.9 , 8.23, 8.7 , 9.  ])) 

node7: (array([ 7,  8,  9, 10]), array([7.9 , 8.23, 8.7 , 9.  ])) 
mean: 8.4575 
left: (array([7]), array([7.9])) 
right: (array([ 8,  9, 10]), array([8.23, 8.7 , 9.  ])) 

node8: (array([ 8,  9, 10]), array([8.23, 8.7 , 9.  ])) 
mean: 8.643333333333333 
left: (array([8]), array([8.23])) 
right: (array([ 9, 10]), array([8.7, 9. ])) 

node9: (array([ 9, 10]), array([8.7, 9. ])) 
mean: 8.85 
left: (array([9]), array([8.7])) 
right: (array([10]), array([9.])) 

上面一共产生了9个节点,为决策树的中间节点(internal node),其中根节点按照书上定义,也属于中间节点,所有在“left”,”right“当中,只由单一样本的为叶子节点。对照前人用python2写出的没有递归的程序,最终答案是一致的,下图为前辈的答案:
转自:https://www.processon.com/view/link/59814675e4b02e2de77789ec

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ML--小小白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值