↑ 点击上方【计算机视觉联盟】关注我们
上一篇已经介绍过决策树基本原理机器学习经典算法决策树原理详解(简单易懂)
纸上得来终觉浅,仅仅懂了原理还不够,要用代码实践才是王道,今天小编就附上小编自己在学习中实践的决策树算法。
1、信息增益
计算给定数据集的熵:
1def calc_shannon_ent(data_set):
2 """计算给定数据集的熵"""
3 num_entries = len(data_set) # 数据集中实例的总数
4
5 # 创建数据字典,键值是最后一列的数值。如果当前键值不存在,则扩展字典并将当前键值加入字典
6 # 每个键值都记录了当前类别出现的次数
7 label_counts = {} # 创建数据字典
8 for feat_vec in data_set:
9 current_label = feat_vec[-1] # 键值是最后一列的数值,表示类别标签
10 # 如果当前键值不存在,则扩展字典并将当前键值加入字典
11 if current_label not in label_counts.keys():
12 label_counts[current_label] = 0
13 label_counts[current_label] += 1
14
15 # 使用所有类标签的发生频率来计算类别出现的概率,并用这个概率来计算熵,统计所有类标签发生的次数
16 shannon_ent = 0
17 for key in label_counts:
18 prob = float(label_counts[key])/num_entries # 计算类标签的概率
19 shannon_ent -= prob * log(prob, 2) # 计算熵
20 return shannon_ent
2、划分数据集
对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式
1# data_set表示待划分的数据集,axis为划分数据集的特征,value指需要返回的特征的值
2def split_data_set(data_set, axis, value):
3 """按照给定的特征划分数据集"""
4 # Python语言在函数中传递的是列表的引用。在函数内部对对象的修改,将会影响该列表对象的整个生存周期。
5 # 为了消除这个不良影响,我们声明一个新列表对象(ret_data_set),用来存储符合要求的值
6 ret_data_set = []
7
8 for feat_vec in data_set:
9 # print(feat_vec)
10 # 将符合特征特征的数据抽取出来
11 if feat_vec[axis] == value:
12 reduced_feat_vec = feat_vec[: axis] # 符合特征值的前边的数据(特征位置之前的数据)
13 # print(reduced_feat_vec)
14 reduced_feat_vec.extend(feat_vec[axis+1:]) # 符合特征值的后边数据(特征位置之后的数据)
15 # print(reduced_feat_vec)
16 ret_data_set.append(reduced_feat_vec)
17 return ret_data_set
代码过程:
1、输入三个参数:带划分的数据集、划分数据集的特征、需要返回的特征的值
2、Python语言在函数中传递的是列表的引用,在函数内部对列表对列表对象的修改,将会影响该列表对象的整个生存周期。为了不修改原始数据集,需要在函数的开始声明一个新列表对象,ret_data_set=[]
3、代码中使用extend和append方法(Python中append()和extend方法的使用和区别)
3、选择最好的数据集划分方式
1def choose_best_feature_to_split(data_set):
2 """选择最好的数据集划分"""
3 num_features = len(data_set[0])-1 # 数据集特征的个数
4 base_entropy = calc_shannon_ent(data_set) # 计算数据集的熵
5 best_info_gain = 0 # 初始化信息最优信息增益
6 best_