一、香农熵
import numpy as np
import pandas as pd
# 定义熵函数
def calEnt(dataset):
n = dataset.shape[0] #数据集总行数
i = dataset.iloc[:,-1].value_counts() #标签的所有类别
p = i/n #每一类标签所占比
ent = (-p*np.log2(p)).sum() #计算信息熵
return ent
#构建数据集
row_data = {
'accompany':[0,0,0,1,1],
'game':[1,1,0,1,1],
'bad boy':['yes','yes','no','no','no']}
data = pd.DataFrame(row_data)
data
calEnt(data)
熵越高,信息的不纯度就越高,则混合的数据就越多。
二、信息增益
划分数据集
数据集最佳切分函数:准则是选择最大信息增益
def bestSplit(dataSet):
baseEnt = calEnt(dataSet) # 计算原始熵
bestGain = 0 # 初始化信息增益
axis = -1 # 初始化最佳切分列,标签列
for i in range(dataSet.shape[1]-1): # 对特征的每一列进行循环
levels= dataSet.iloc[:,i].value_counts().index # 提取出当前列的所有取值
ents = 0 # 初始化子节点的信息熵
for j in levels: