从头开始实现决策树,并将其应用于对蘑菇是可食用还是有毒的分类任务。
1 导包
import numpy as np
import matplotlib.pyplot as plt
from public_tests import *
%matplotlib inline
2 问题描述
假设您正在创办一家种植和销售野生蘑菇的公司。 由于并非所有蘑菇都可以食用,因此您希望能够根据其物理属性来判断给定的蘑菇是可食用的还是有毒的 您有一些可用于此任务的现有数据。 你能用这些数据来帮助你确定哪些蘑菇可以安全销售吗? 注意:使用的数据集仅用于说明目的。它并不意味着作为识别食用蘑菇的指南。
3 one-hot 编码数据集
Brown Cap | Tapering Stalk Shape | Solitary | Edible |
---|---|---|---|
1 | 1 | 1 | 1 |
1 | 0 | 1 | 1 |
1 | 0 | 0 | 0 |
1 | 0 | 0 | 0 |
1 | 1 | 1 | 1 |
0 | 1 | 1 | 0 |
0 | 0 | 0 | 0 |
1 | 0 | 1 | 1 |
0 | 1 | 0 | 1 |
1 | 0 | 0 | 0 |
因此, |
-
X_train
包含每个样本的三个特征- 帽子颜色(值为
1
表示棕色帽子,值为0
表示红色帽子) - 茎形缩小(值为
1
表示“锥形茎形”,值为0
表示“扩大”茎形) - 单独(值为
1
表示“是”,值为0
表示“否”)
- 帽子颜色(值为
-
y_train
是蘑菇是否可食用y = 1
表示可食用y = 0
表示有毒
X_train = np.array([[1,1,1],[1,0,1],[1,0,0],[1,0,0],[1,1,1],[0,1,1],[0,0,0],[1,0,1],[0,1,0],[1,0,0]])
y_train = np.array([1,1,0,0,1,0,0,1,1,0])
X_train[:5]
array([[1, 1, 1],
[1, 0, 1],
[1, 0, 0],
[1, 0, 0],
[1, 1, 1]])
4决策树
在这个实践实验中,您将基于提供的数据集构建一个决策树。
-
回想一下构建决策树的步骤:
- 从根节点开始使用所有的样本
- 计算基于所有可能特征分割时的信息增益,并选择具有最高信息增益的特征
- 根据所选特征对数据集进行分割,并创建树的左右分支
- 持续重复分裂过程直到满足停止标准
-
在本实验中,您将实现以下函数,以便使用具有最高信息增益的特征将节点分成左右两个分支:
- 计算节点上的熵
- 基于给定特征在一个节点上将数据集分为左右两个分支
- 计算在给定特征上分裂时的信息增益
- 选择最大化信息增益的特征
-
然后我们将使用您实现的辅助函数通过重复分裂过程来构建决策树,直到满足停止标准为止。
- 对于本实验,我们选择的停止标准是设置最大深度为2。
4.1 计算熵
首先,您需要编写一个名为compute_entropy
的帮助函数,用于计算节点上的熵(杂质度量)。
- 函数接受一个numpy数组(
y
),该数组指示该节点中的示例是否可食用(1
)或有毒(0
)
请完成下面的compute_entropy()
函数以:
- 计算 p 1 p_1 p1,它是可食用示例(即在
y
中具有值=1
)的比例 - 然后计算熵
H ( p 1 ) = − p 1 log 2 (