为表4.3中数据生成一棵决策树。
代码是在《机器学习实战》的代码基础上改良的,借用了numpy, pandas之后明显简化了代码。表4.3的数据特征是离散属性和连续属性都有,问题就复杂在这里。话不多说,看代码。
先定义几个辅助函数,正常的思路是先想宏观算法,然后需要什么函数就定义什么函数。
import math
import pandas as pd
import numpy as np
from treePlotter import createPlot
def entropy(data):
label_values = data[data.columns[-1]]
#Returns object containing counts of unique values.
counts = label_values.value_counts()
s = 0
for c in label_values.unique():
freq = float(counts[c])/len(label_values)
s -= freq*math.log(freq,2)
return s
def is_continuous(data,attr):
"""Check if attr is a continuous attribute"""
return data[attr].dtype == 'float64'
def split_points(data,attr):
"""Returns Ta,Equation(4.7),p.84"""
values = np.sort(data[attr].values)
return [(x+y)/2 for x,y in zip(values[:-1],values[1:])]
treePlotter是《实战》里的模块,用来把决策树画出来。这里决策树是用字典表示的,key可以表示树的节点或分枝,表示节点的时候是属性,表示分枝的时候是属性值。value又是一个字典或字符串,是字符串的时候表示叶,也就是标记。这里的data是pandas里的DataFrame,形式上像一个表,对表的常见操作它都可以方便的解决。命名习惯跟书上一致。
再继续看怎么计算信息增益:
def discrete_gain(data,attr):
V = data[attr].unique()
s = 0
for v in V:
data_v = data[data[attr]== v]
s += float(len(data_v))/len(data)*entropy(data_v)
return (entropy(data) - s,None)