1. 计算信息熵
def calc_ent(x):
    """Calculate the Shannon entropy of x, in bits.

    Each distinct value in ``x`` is treated as one category, and the entropy
    ``-sum(p * log2(p))`` is taken over the empirical category frequencies.

    Args:
        x: Array-like of category labels. It is converted with
            ``np.asarray`` and counted over all of its elements; the values
            must be sortable, as ``np.unique`` requires.

    Returns:
        The entropy as a float; ``0.0`` for empty input or when every
        element is the same.
    """
    values = np.asarray(x)
    total = values.size
    if total == 0:
        # No observations: report zero entropy instead of dividing by zero.
        return 0.0

    # One pass yields every category count, replacing a boolean mask per value.
    _, counts = np.unique(values, return_counts=True)
    probs = counts / total

    # Every count is >= 1, so log2 never sees zero.  Subtracting from 0.0
    # keeps a single-category input at +0.0 rather than -0.0.
    return 0.0 - float(np.dot(probs, np.log2(probs)))
2. 计算条件信息熵
def calc_condition_ent(x, y):
    """Calculate the conditional entropy H(y|x).

    For every distinct value found in ``x``, the entries of ``y`` at the
    matching positions are selected, their entropy is computed with
    ``calc_ent``, and the result is weighted by the fraction of ``y`` that
    the selection covers.

    Args:
        x: Conditioning variable; indexed by position and compared
            element-wise (assumed to be a 1-D NumPy array -- TODO confirm).
        y: Target variable, indexed with the boolean mask derived from ``x``
            (assumed to have the same length as ``x`` -- TODO confirm).

    Returns:
        The weighted sum of the per-value entropies of ``y``.
    """
    distinct_x = {x[pos] for pos in range(x.shape[0])}

    conditional = 0.0
    for value in distinct_x:
        # Entries of y that co-occur with this particular value of x.
        subset = y[x == value]
        subset_entropy = calc_ent(subset)
        weight = float(subset.shape[0]) / y.shape[0]
        conditional += weight * subset_entropy
    return conditional
3. 计算信息增益
ent_grap = H(Y) - H(Y|X)
def calc_ent_grap(x, y):
    """Calculate the information gain of y given x: H(y) - H(y|x).

    Args:
        x: Conditioning variable, passed through to ``calc_condition_ent``.
        y: Target variable whose entropy is measured with ``calc_ent``.

    Returns:
        ``calc_ent(y) - calc_condition_ent(x, y)``.
    """
    # The left operand is evaluated first, so H(y) is computed before H(y|x).
    return calc_ent(y) - calc_condition_ent(x, y)