决策树
一、决策树理解
决策树是监督学习中的分类算法。决策树的构造是通过信息增益来选择最佳的分类特征值,根据训练集对应特征值取值不同分成若干子集,通过递归的形式处理子集,直到子集中的所有样本同属一类或者已无特征值可分。决策树就是通过多个条件判断组成的,而其关键就是如何选取最适合分类的特征值。
二、最适合的特征值的判断标准——信息增益
- 信息:
x i x_i xi的信息为: l ( x i ) = − log 2 p ( x i ) l(x_i) = - \log_2 p(x_i) l(xi)=−log2p(xi),其中 p ( x i ) p(x_i) p(xi)是 x i x_i xi出现的概率。
从公式中可以看出,概率越大,信息越小。我们可以这样理解,概率大的情况出现了是很正常的事情,正常的事情给我们带来的信息也就无关紧要,所以其对应的信息小。相反,概率越小,信息越大 - 熵定义为信息的期望值:
H = − ∑ i = 1 n p ( x i ) log 2 p ( x i ) H = - \sum_{i=1}^{n} p(x_i) \log_2 p(x_i) H=−∑i=1np(xi)log2p(xi)
熵越大代表着数据更加杂乱无序 - 信息增益: 熵的减少或信息无序度的减少。寻找最适合的特征值就是,最大的信息增益。
三、算法
from math import log
import operator as operator
'''
data_set: 二维数组,训练集数据,第i行对应第i个特征值,最后一行为分类结果
labels:数组,对应data_set中特征值的描述
'''
#计算data_set的香农熵
def calc_shannon_ent(data_set):
label_count = {}
m = len(data_set)
for i in range(m):
label = data_set[i][-1]
if label not in label_count.keys():
label_count[label] = 0
label_count[label] += 1
shannon_ent = 0
for key in label_count.keys():
prob = label_count[key] / float(m)
shannon_ent -= prob * log(prob, 2)
return shannon_ent
'''
axis: 用于划分的特征值对应索引
value:用于划分的特征值取值
ret_data_set: 返回值,返回axis对应特征值的取值为value的样本集合
'''
def split_data_set(data_set, axis, value):
ret_data_set = []
for sample in data_set:
if sample[axis] == value:
#新集合里不再含有已经用于划分的特征值
reduce_feat = sample[:axis]
reduce_feat.extend(sample[(axis + 1):])
ret_data_set.append(reduce_feat)
return ret_data_set
#选取最适合的特征值
def choose_best_feature_to_split(data_set):
base_shannon_ent = calc_shannon_ent(data_set)
feature_num = len(data_set[0])-1
m = len(data_set)
max_info_gain = 0
best_feature = 0
for i in range(feature_num):
feature_set = set([sample[i] for sample in data_set])
split_shannon_ent = 0
for feature_val in feature_set:
sub_data_set = split_data_set(data_set, i, feature_val)
prob = len(sub_data_set) / float(m)
#通过各个子集的熵,计算划分后的总熵
split_shannon_ent += prob * calc_shannon_ent(sub_data_set)
info_gain = base_shannon_ent - split_shannon_ent
if info_gain > max_info_gain:
max_info_gain = info_gain
best_feature = i
return best_feature
#选出class_list中最多的class
def majority_cnt(class_list):
class_count = {}
for c in class_list:
if c not in class_count.keys():
class_count[c] = 0
class_count[c] += 1
sorted_class = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
return sorted_class[0][0]
#决策树构建(训练过程)
def create_tree(data_set, labels):
class_list = [sample[-1] for sample in data_set]
#无特征值可用的情况
if len(data_set) == 1:
return majority_cnt(class_list)
#子集中所有样本同类,无需再分
if class_list.count(class_list[0]) == len(data_set):
return class_list[0]
#找出最适合特征值,根据特征值分类
best_feature = choose_best_feature_to_split(data_set)
best_label = labels.pop(best_feature)
#用字典的结构存决策树
tree = {best_label:{}}
best_feature_set = set([sample[best_feature] for sample in data_set])
for val in best_feature_set:
sub_data_set = split_data_set(data_set, best_feature, val)
#对子集递归处理
tree[best_label][val] = create_tree(sub_data_set, labels.copy())
return tree
'''
测试,应用
input_tree:上述用字典表示的决策树
feat_labels:特征值标签
test_vec:数组,测试分类数据
'''
def classify(input_tree,feat_labels,test_vec):
feat_label = list(input_tree.keys())[0]
inside_tree = input_tree[feat_label]
key = test_vec[feat_labels.index(feat_label)]
result = inside_tree[key]
if isinstance(result, dict):
return classify(result, feat_labels,test_vec)
else:
return result
四、其他
- 使用的构造决策树算法:ID3
- 计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征值数据
- 存在过度匹配问题,可以裁剪决策树