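A decision tree is an extremely easy-to-understand algorithm: once the model is built, it is nothing more than a chain of nested if...else statements (or nested switch statements).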
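The sketch below makes the "nested if...else" view concrete. It writes the same two-level tree in two ways: as the nested dict format the code in this post builds (`{feature: {value: subtree or class}}`), and as hand-written if/else. The features and classes (`outlook`, `windy`, `play`/`stay`) are made up for illustration.

```python
# Hypothetical tree in nested-dict form: {feature: {value: subtree or class label}}
tree = {'outlook': {'sunny': 'play',
                    'rainy': {'windy': {'yes': 'stay', 'no': 'play'}}}}


def predict_by_hand(outlook, windy):
    """The same tree written out as nested if/else."""
    if outlook == 'sunny':
        return 'play'
    else:  # outlook == 'rainy'
        if windy == 'yes':
            return 'stay'
        else:
            return 'play'


def predict_from_dict(node, sample):
    """Walk the dict: at each node, follow the branch for the sample's value of that feature."""
    while isinstance(node, dict):
        feature = next(iter(node))
        node = node[feature][sample[feature]]
    return node


print(predict_by_hand('rainy', 'yes'))                                # stay
print(predict_from_dict(tree, {'outlook': 'rainy', 'windy': 'yes'}))  # stay
```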
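Pros: low computational cost, easy-to-understand output, insensitive to missing intermediate values, and able to handle irrelevant features.
Cons: prone to overfitting.
Applicable data types: numeric and nominal. This implementation splits on exact value matches, so each distinct numeric value is treated as its own category; continuous features are usually discretized first.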
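A Python implementation of the decision tree:
(I) First, a few utility functions: computing entropy, splitting the dataset, and picking the majority class.
(1) Computing entropy. Entropy measures how disordered a set is: the more disordered the set, the higher its entropy. If p_i is the fraction of rows belonging to class i, the entropy is H = -sum(p_i * log2(p_i)) over all classes.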
```python
from math import log2


def entropy(dataset):
    """Shannon entropy (in bits) of the class labels, i.e. the last column of each row."""
    # Count how many rows belong to each class
    results = {}
    for row in dataset:
        r = row[-1]
        results[r] = results.get(r, 0) + 1
    # H = -sum(p * log2(p)) over all classes
    ent = 0.0
    for count in results.values():
        p = count / len(dataset)
        ent -= p * log2(p)
    return ent
```
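As a quick sanity check on the definition, the sketch below evaluates three cases with a condensed copy of `entropy`: a pure set, an evenly split two-class set, and the class column of the contact-lens data that appears later in this post (15 "no lenses", 5 "soft" and 4 "hard" out of 24 rows). The variable names are made up for the example.

```python
from math import log2


def entropy(dataset):
    # Same computation as above: entropy of the last column (the class label)
    counts = {}
    for row in dataset:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    ent = 0.0
    for c in counts.values():
        p = c / len(dataset)
        ent -= p * log2(p)
    return ent


pure = [['a', 'yes'], ['b', 'yes']]   # one class only: perfectly ordered
even = [['a', 'yes'], ['b', 'no']]    # two classes, 50/50
# Class column of the contact-lens data: 15 "no lenses", 5 "soft", 4 "hard"
lenses_classes = [['no lenses']] * 15 + [['soft']] * 5 + [['hard']] * 4

print(entropy(pure))                      # 0.0
print(entropy(even))                      # 1.0
print(round(entropy(lenses_classes), 3))  # 1.326
```

More classes and a more even spread push the entropy up. The maximum for three classes is log2(3), about 1.585 bits.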
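(2) Extracting a subset by feature and value. The function returns the rows whose k-th feature equals v, and drops that feature from each returned row, because the split has used it up: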
```python
def fetch_subdataset(dataset, k, v):
    """Rows whose k-th feature equals v, with column k removed."""
    return [d[:k] + d[k + 1:] for d in dataset if d[k] == v]
```
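Here is a small usage example. The rows are made up, but they follow the same layout as the contact-lens data (features first, class last). The matched column is dropped from every returned row:

```python
def fetch_subdataset(dataset, k, v):
    # Keep rows whose column k equals v, and drop column k from each of them
    return [d[:k] + d[k + 1:] for d in dataset if d[k] == v]


# Hypothetical rows: [age, prescript, class]
rows = [['young', 'myope', 'no lenses'],
        ['young', 'hyper', 'soft'],
        ['pre', 'myope', 'hard']]

print(fetch_subdataset(rows, 0, 'young'))  # [['myope', 'no lenses'], ['hyper', 'soft']]
print(fetch_subdataset(rows, 1, 'myope'))  # [['young', 'no lenses'], ['pre', 'hard']]
```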
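(3) Finding the majority class. Sometimes every feature has been used while building the tree, yet the remaining rows still belong to more than one class. In that case we fall back to a majority vote to decide the final class: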
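```python
def get_max_feature(class_list):
    """Majority vote: return the most frequent class label in class_list."""
    class_count = {}
    for cla in class_list:
        class_count[cla] = class_count.get(cla, 0) + 1
    # Sort (label, count) pairs by count, largest first
    sorted_class_count = sorted(class_count.items(), key=lambda d: d[1], reverse=True)
    return sorted_class_count[0][0]
```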
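The same majority vote can be written with `collections.Counter` from the standard library. `majority_class` is a hypothetical name used only here. On Python 3.7+ both versions break ties the same way, in favour of the label encountered first. This is because `sorted` is stable (also with `reverse=True`), and `Counter.most_common` orders equal counts by first occurrence.

```python
from collections import Counter


def majority_class(class_list):
    # most_common(1) returns [(label, count)] for the most frequent label
    return Counter(class_list).most_common(1)[0][0]


print(majority_class(['soft', 'hard', 'soft', 'no lenses']))  # soft
```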
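(II) Choosing the best way to split the dataset:
Which column should we split on to get the largest information gain? Information gain is the entropy of the whole set minus the weighted average entropy of the subsets a split produces (each subset weighted by its share of the rows). The first term is the same for every candidate column. So maximizing information gain is the same as minimizing the weighted subset entropy, and that is what the function below does: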
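```python
def choose_decision_feature(dataset):
    """Return the index of the feature whose split has the lowest weighted entropy,
    i.e. the highest information gain. The last column is the class, not a feature."""
    ent, feature = float('inf'), -1  # best weighted entropy so far, and its column
    for i in range(len(dataset[0]) - 1):
        ent_t = 0.0
        for f in set(e[i] for e in dataset):
            sub_data = fetch_subdataset(dataset, i, f)
            # Weight each subset's entropy by its share of the rows
            ent_t += entropy(sub_data) * len(sub_data) / len(dataset)
        if ent_t < ent:
            ent, feature = ent_t, i
    return feature
```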
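The sketch below computes the information gain explicitly and picks the column with the largest value. The gain is H(D) minus the weighted subset entropy. The dataset is small and hypothetical: the columns `can_swim` and `has_flippers` and the fish/not-fish labels are invented. `information_gain` is a helper introduced only for this example.

```python
from math import log2


def entropy(rows):
    # Entropy of the last column (class label)
    counts = {}
    for r in rows:
        counts[r[-1]] = counts.get(r[-1], 0) + 1
    ent = 0.0
    for c in counts.values():
        p = c / len(rows)
        ent -= p * log2(p)
    return ent


def information_gain(rows, i):
    """Gain(D, i) = H(D) - sum over values v of |D_v|/|D| * H(D_v), splitting on column i."""
    weighted = 0.0
    for v in set(r[i] for r in rows):
        sub = [r for r in rows if r[i] == v]
        weighted += len(sub) / len(rows) * entropy(sub)
    return entropy(rows) - weighted


# Hypothetical toy data: [can_swim, has_flippers, is_fish]
toy = [['yes', 'yes', 'yes'],
       ['yes', 'yes', 'yes'],
       ['yes', 'no', 'no'],
       ['no', 'yes', 'no'],
       ['no', 'yes', 'no']]

gains = [information_gain(toy, i) for i in range(len(toy[0]) - 1)]
print([round(g, 3) for g in gains])  # [0.42, 0.171]
best = max(range(len(gains)), key=lambda i: gains[i])
print(best)                          # 0
```

Column 0 has the largest gain. It is also the column with the smallest weighted entropy, which `choose_decision_feature` would return.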
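(III) Building the decision tree recursively. The recursion stops in two cases:
- every row in the current subset has the same class, and that class becomes a leaf;
- no features are left, and the majority class becomes a leaf.

Otherwise it splits on the best feature and recurses into each of that feature's values: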
```python
def build_decision_tree(dataset, datalabel):
    cla = [c[-1] for c in dataset]
    # Stop 1: all rows share one class, so that class is the leaf
    if cla.count(cla[0]) == len(cla):
        return cla[0]
    # Stop 2: only the class column is left, so take a majority vote over the class labels
    if len(dataset[0]) == 1:
        return get_max_feature(cla)
    feature = choose_decision_feature(dataset)
    feature_label = datalabel[feature]
    decision_tree = {feature_label: {}}
    # Labels of the remaining features, built as a new list so the caller's list is not modified
    sub_label = datalabel[:feature] + datalabel[feature + 1:]
    for value in set(d[feature] for d in dataset):
        decision_tree[feature_label][value] = build_decision_tree(
            fetch_subdataset(dataset, feature, value), sub_label)
    return decision_tree
```
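The sketch below shows both stopping rules in action. It runs a condensed, self-contained copy of the functions above on a hypothetical four-row dataset. Here the majority vote uses `collections.Counter`, which is equivalent to `get_max_feature`. The single feature `f` is used up while the 'a' branch still mixes 'x' and 'y', so that branch is decided by majority vote.

```python
from collections import Counter
from math import log2


def entropy(dataset):
    ent = 0.0
    for c in Counter(row[-1] for row in dataset).values():
        p = c / len(dataset)
        ent -= p * log2(p)
    return ent


def fetch_subdataset(dataset, k, v):
    return [d[:k] + d[k + 1:] for d in dataset if d[k] == v]


def choose_decision_feature(dataset):
    best_ent, feature = float('inf'), -1
    for i in range(len(dataset[0]) - 1):
        ent_t = 0.0
        for f in set(e[i] for e in dataset):
            sub = fetch_subdataset(dataset, i, f)
            ent_t += entropy(sub) * len(sub) / len(dataset)
        if ent_t < best_ent:
            best_ent, feature = ent_t, i
    return feature


def build_decision_tree(dataset, datalabel):
    cla = [row[-1] for row in dataset]
    if cla.count(cla[0]) == len(cla):      # pure node: leaf
        return cla[0]
    if len(dataset[0]) == 1:               # no features left: majority vote
        return Counter(cla).most_common(1)[0][0]
    feature = choose_decision_feature(dataset)
    sub_labels = datalabel[:feature] + datalabel[feature + 1:]
    return {datalabel[feature]: {
        value: build_decision_tree(fetch_subdataset(dataset, feature, value), sub_labels)
        for value in set(row[feature] for row in dataset)}}


# Hypothetical data: one feature "f", class in the last column
data = [['a', 'x'], ['a', 'y'], ['a', 'x'], ['b', 'y']]
labels = ['f']
print(build_decision_tree(data, labels))  # {'f': {'a': 'x', 'b': 'y'}} (key order may vary)
print(labels)                             # ['f'] (the caller's list is unchanged)
```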
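(IV) Using the decision tree to classify a new sample. Start at the root. At each node, follow the branch that matches the sample's value for the feature tested there, until you reach a leaf (a class label):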
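```python
def classify(decision_tree, feat_labels, testVec):
    # The single key at this level is the name of the feature tested at this node
    label = next(iter(decision_tree))
    next_dict = decision_tree[label]
    feat_index = feat_labels.index(label)
    c_label = None  # stays None if testVec has a value that never appeared in training
    for key in next_dict:
        if testVec[feat_index] == key:
            if isinstance(next_dict[key], dict):
                # Internal node: keep descending
                c_label = classify(next_dict[key], feat_labels, testVec)
            else:
                # Leaf: the class label
                c_label = next_dict[key]
    return c_label
```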
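Below is a usage sketch on a small hypothetical tree; the feature names `can_swim` and `has_flippers` are made up. This variant looks up the branch with `dict.get` instead of scanning every key. It returns the same label as the loop version, and `None` for a feature value never seen in training.

```python
def classify(decision_tree, feat_labels, testVec):
    label = next(iter(decision_tree))              # feature tested at this node
    value = testVec[feat_labels.index(label)]      # the sample's value for that feature
    subtree = decision_tree[label].get(value)      # None if the value was never seen
    if isinstance(subtree, dict):
        return classify(subtree, feat_labels, testVec)
    return subtree


# Hypothetical tree and feature labels
tree = {'can_swim': {'no': 'no', 'yes': {'has_flippers': {'no': 'no', 'yes': 'yes'}}}}
labels = ['can_swim', 'has_flippers']

print(classify(tree, labels, ['yes', 'yes']))    # yes
print(classify(tree, labels, ['yes', 'no']))     # no
print(classify(tree, labels, ['maybe', 'yes']))  # None
```

A dict lookup goes straight to the matching branch, so each level costs one lookup rather than a pass over all branches.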
(V) Persisting the decision tree
(1) Saving
```python
import pickle


def store_decision_tree(tree, filename):
    # pickle writes bytes, so the file must be opened in binary mode
    with open(filename, 'wb') as f:
        pickle.dump(tree, f)
```
(2) Loading
```python
import pickle


def load_decision_tree(filename):
    # Binary mode to match how the tree was written
    with open(filename, 'rb') as f:
        return pickle.load(f)
```
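Here is a round-trip check. It assumes a writable temporary directory; `tempfile.mkdtemp` only provides a scratch path, and the tree is a hypothetical stand-in. Using `with` also guarantees the file is closed even if pickling fails.

```python
import os
import pickle
import tempfile


def store_decision_tree(tree, filename):
    with open(filename, 'wb') as f:
        pickle.dump(tree, f)


def load_decision_tree(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)


tree = {'tearRate': {'reduced': 'no lenses', 'normal': 'soft'}}  # hypothetical tree
path = os.path.join(tempfile.mkdtemp(), 'tree.pkl')
store_decision_tree(tree, path)
print(load_decision_tree(path) == tree)  # True
```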
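(VI) We've made it to the end, so back to the main topic: finding the right lenses for the guy in glasses.
The contact-lens dataset below comes from the UCI Machine Learning Repository. For each patient it records several observed eye conditions, together with the contact-lens type the doctor recommended: hard, soft, or no contact lenses.
The data is as follows. In lenses.txt each row holds these five fields separated by tabs: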
| age | prescript | astigmatic | tearRate | class |
|---|---|---|---|---|
| young | myope | no | reduced | no lenses |
| young | myope | no | normal | soft |
| young | myope | yes | reduced | no lenses |
| young | myope | yes | normal | hard |
| young | hyper | no | reduced | no lenses |
| young | hyper | no | normal | soft |
| young | hyper | yes | reduced | no lenses |
| young | hyper | yes | normal | hard |
| pre | myope | no | reduced | no lenses |
| pre | myope | no | normal | soft |
| pre | myope | yes | reduced | no lenses |
| pre | myope | yes | normal | hard |
| pre | hyper | no | reduced | no lenses |
| pre | hyper | no | normal | soft |
| pre | hyper | yes | reduced | no lenses |
| pre | hyper | yes | normal | no lenses |
| presbyopic | myope | no | reduced | no lenses |
| presbyopic | myope | no | normal | no lenses |
| presbyopic | myope | yes | reduced | no lenses |
| presbyopic | myope | yes | normal | hard |
| presbyopic | hyper | no | reduced | no lenses |
| presbyopic | hyper | no | normal | soft |
| presbyopic | hyper | yes | reduced | no lenses |
| presbyopic | hyper | yes | normal | no lenses |
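```python
def test():
    # lenses.txt: one patient per line, five tab-separated fields (the last is the class)
    with open('lenses.txt') as f:
        lense_data = [inst.strip().split('\t') for inst in f if inst.strip()]
    lense_label = ['age', 'prescript', 'astigmatic', 'tearRate']
    lense_tree = build_decision_tree(lense_data, lense_label)
    return lense_tree
```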
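To try this without creating lenses.txt, the sketch below embeds the 24 rows from the table above as a string. It then runs a condensed copy of the functions on them. `build`, `choose`, `split` and the other short names are hypothetical abbreviations, not the post's function names. `str.split(maxsplit=4)` keeps the two-word class "no lenses" in a single field. The printed layout of the full tree depends on set iteration order, so the sketch prints only the root feature and one prediction.

```python
from collections import Counter
from math import log2

# The 24 rows from the table above: age, prescript, astigmatic, tearRate, class
LENSES = """\
young myope no reduced no lenses
young myope no normal soft
young myope yes reduced no lenses
young myope yes normal hard
young hyper no reduced no lenses
young hyper no normal soft
young hyper yes reduced no lenses
young hyper yes normal hard
pre myope no reduced no lenses
pre myope no normal soft
pre myope yes reduced no lenses
pre myope yes normal hard
pre hyper no reduced no lenses
pre hyper no normal soft
pre hyper yes reduced no lenses
pre hyper yes normal no lenses
presbyopic myope no reduced no lenses
presbyopic myope no normal no lenses
presbyopic myope yes reduced no lenses
presbyopic myope yes normal hard
presbyopic hyper no reduced no lenses
presbyopic hyper no normal soft
presbyopic hyper yes reduced no lenses
presbyopic hyper yes normal no lenses
"""
LABELS = ['age', 'prescript', 'astigmatic', 'tearRate']


def entropy(rows):
    ent = 0.0
    for count in Counter(r[-1] for r in rows).values():
        p = count / len(rows)
        ent -= p * log2(p)
    return ent


def split(rows, k, v):
    # Same as fetch_subdataset: rows with r[k] == v, column k removed
    return [r[:k] + r[k + 1:] for r in rows if r[k] == v]


def choose(rows):
    # Same as choose_decision_feature: smallest weighted entropy wins (first one on ties)
    def weighted_entropy(i):
        subsets = [split(rows, i, v) for v in set(r[i] for r in rows)]
        return sum(entropy(s) * len(s) / len(rows) for s in subsets)
    return min(range(len(rows[0]) - 1), key=weighted_entropy)


def build(rows, labels):
    classes = [r[-1] for r in rows]
    if classes.count(classes[0]) == len(classes):
        return classes[0]                             # pure leaf
    if len(rows[0]) == 1:
        return Counter(classes).most_common(1)[0][0]  # majority vote
    best = choose(rows)
    rest = labels[:best] + labels[best + 1:]
    return {labels[best]: {v: build(split(rows, best, v), rest)
                           for v in set(r[best] for r in rows)}}


def classify(tree, labels, sample):
    # Follow matching branches until a leaf; None if a value was never seen
    while isinstance(tree, dict):
        feature = next(iter(tree))
        tree = tree[feature].get(sample[labels.index(feature)])
    return tree


data = [line.split(maxsplit=4) for line in LENSES.splitlines()]
tree = build(data, LABELS)
print(next(iter(tree)))                                             # tearRate
print(classify(tree, LABELS, ['young', 'myope', 'yes', 'normal']))  # hard
```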
And with that, the guy in glasses can finally get the right lenses...
Here is all of the code in one place:
```python
import pickle
from math import log2


def entropy(dataset):
    """Shannon entropy (in bits) of the class labels in the last column."""
    results = {}
    for row in dataset:
        r = row[-1]
        results[r] = results.get(r, 0) + 1
    ent = 0.0
    for count in results.values():
        p = count / len(dataset)
        ent -= p * log2(p)
    return ent


def fetch_subdataset(dataset, k, v):
    """Rows whose k-th feature equals v, with column k removed."""
    return [d[:k] + d[k + 1:] for d in dataset if d[k] == v]


def get_max_feature(class_list):
    """Majority vote: the most frequent class label."""
    class_count = {}
    for cla in class_list:
        class_count[cla] = class_count.get(cla, 0) + 1
    sorted_class_count = sorted(class_count.items(), key=lambda d: d[1], reverse=True)
    return sorted_class_count[0][0]


def choose_decision_feature(dataset):
    """Index of the feature with the lowest weighted entropy (highest information gain)."""
    ent, feature = float('inf'), -1
    for i in range(len(dataset[0]) - 1):
        ent_t = 0.0
        for f in set(e[i] for e in dataset):
            sub_data = fetch_subdataset(dataset, i, f)
            ent_t += entropy(sub_data) * len(sub_data) / len(dataset)
        if ent_t < ent:
            ent, feature = ent_t, i
    return feature


def build_decision_tree(dataset, datalabel):
    cla = [c[-1] for c in dataset]
    if cla.count(cla[0]) == len(cla):  # all rows in one class
        return cla[0]
    if len(dataset[0]) == 1:           # no features left: majority vote
        return get_max_feature(cla)
    feature = choose_decision_feature(dataset)
    feature_label = datalabel[feature]
    decision_tree = {feature_label: {}}
    sub_label = datalabel[:feature] + datalabel[feature + 1:]  # caller's list stays intact
    for value in set(d[feature] for d in dataset):
        decision_tree[feature_label][value] = build_decision_tree(
            fetch_subdataset(dataset, feature, value), sub_label)
    return decision_tree


def store_decision_tree(tree, filename):
    with open(filename, 'wb') as f:  # pickle writes bytes
        pickle.dump(tree, f)


def load_decision_tree(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)


def classify(decision_tree, feat_labels, testVec):
    label = next(iter(decision_tree))
    next_dict = decision_tree[label]
    feat_index = feat_labels.index(label)
    c_label = None  # returned if testVec has a value unseen in training
    for key in next_dict:
        if testVec[feat_index] == key:
            if isinstance(next_dict[key], dict):
                c_label = classify(next_dict[key], feat_labels, testVec)
            else:
                c_label = next_dict[key]
    return c_label


def test():
    with open('lenses.txt') as f:
        lense_data = [inst.strip().split('\t') for inst in f if inst.strip()]
    lense_label = ['age', 'prescript', 'astigmatic', 'tearRate']
    lense_tree = build_decision_tree(lense_data, lense_label)
    return lense_tree


if __name__ == "__main__":
    tree = test()
    print(tree)
```