一、详情
1.1 目的:
求解信息增益
1.2 参数如下
- getEntropy(p)——其中 p 是一个 list 类型,元素为属性各取值的概率
二、求解
2.1 代码
import numpy as np
# 求信息熵
def getEntropy(p):
    """Return the Shannon entropy (base 2) of a discrete distribution.

    Parameters
    ----------
    p : list of float
        Probabilities of each value of the attribute. Entries equal to 0
        are skipped, following the convention 0 * log2(0) = 0, so callers
        no longer need to pre-filter zero probabilities.

    Returns
    -------
    float
        -sum(p_i * log2(p_i)) over the non-zero entries.
    """
    probs = np.asarray(p, dtype=float)
    # Mask out zeros: np.log2(0) would warn and yield -inf, while the
    # mathematical limit of p*log2(p) as p -> 0 is 0.
    nz = probs[probs > 0]
    return float(-np.sum(nz * np.log2(nz)))
def calcInfoGain(feature, label, index):
    """Return the information gain of feature column ``index`` w.r.t. ``label``.

    Gain = H(label) - sum_v P(feature==v) * H(label | feature==v).

    Parameters
    ----------
    feature : ndarray of shape (m, k)
        Sample-by-attribute matrix of discrete feature values.
    label : ndarray of shape (m, 1)
        Class label for each sample (only column 0 is used).
    index : int
        Which feature column to evaluate.

    Returns
    -------
    float
        The information gain of the selected feature column.
    """
    def _entropy(counts, total):
        # Entropy of a distribution given by counts/total; zero counts are
        # skipped (0 * log2(0) -> 0 by convention, avoids log2(0) warnings).
        probs = np.asarray([c / total for c in counts if c > 0], dtype=float)
        return float(-np.sum(probs * np.log2(probs))) if probs.size else 0.0

    labels = label[:, 0]
    m = labels.shape[0]
    label_values = set(labels)

    # Entropy of the full label column.
    total_counts = [np.count_nonzero(labels == v) for v in label_values]
    entropy_all = _entropy(total_counts, m)

    # Weighted conditional entropy over each value of the chosen feature.
    column = feature[:, index]
    cond_entropy = 0.0
    for fv in set(column):
        # Boolean-index once per feature value instead of recomputing the
        # masked view inside the inner loop (the original rebuilt it 3x).
        subset = labels[column == fv]
        n = subset.shape[0]
        counts = [np.count_nonzero(subset == v) for v in label_values]
        cond_entropy += (n / m) * _entropy(counts, n)

    return entropy_all - cond_entropy
if __name__ == '__main__':
    # Example data taken from the exercise: one feature column, 15 samples.
    feature = np.array([1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1]).reshape(-1, 1)
    label = np.array([0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0]).reshape(-1, 1)
    # Information gain of feature column 0 with respect to the labels.
    print(calcInfoGain(feature, label, 0))
2.2 实验截图
三、收获
- 掌握set小trick
- 掌握 ndarray 布尔型索引用法
四、备注
- 首先该代码不可直接用于头歌题目,因为头歌所给参数 label 的类型并不是 ndarray,而是 list
- 其次,上述实验数据取自头歌题目信息