Decision tree classification based on information gain is a fairly common classification method, and the feature attributes are generally nominal (categorical) data.
The underlying principle is simple, so no derivation is given here. Many of the programs available online are written for Python 2.x; below is a version based on Python 3.6 for reference. Comments and discussion are welcome!
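For reference, the two quantities the code below computes are the Shannon entropy of a dataset $D$ and the information gain of splitting $D$ on a feature $A$:

$H(D) = -\sum_{k} p_k \log_2 p_k, \qquad \mathrm{Gain}(D, A) = H(D) - \sum_{v \in \mathrm{values}(A)} \frac{|D_v|}{|D|}\, H(D_v)$

where $p_k$ is the fraction of samples belonging to class $k$ and $D_v$ is the subset of $D$ whose feature $A$ takes the value $v$. At each node the feature with the largest gain is chosen.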
def create_dataset():
    data_set = [[1, 1, 'yes'],
                [1, 1, 'yes'],
                [1, 0, 'no'],
                [0, 1, 'no'],
                [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return data_set, labels
from math import log


# Compute the Shannon entropy of a dataset (the class label is the last column)
def cal_ent(dataset):
    num = len(dataset)
    label_counts = {}
    for vect in dataset:
        temp = vect[-1]
        if temp not in label_counts.keys():
            label_counts[temp] = 0
        label_counts[temp] += 1
    # turn the class counts into probabilities
    for key, value in label_counts.items():
        label_counts[key] = float(value) / num
    shannon_entropy = 0
    for value in label_counts.values():
        shannon_entropy -= value * log(value, 2)
    return shannon_entropy
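As a quick sanity check (assuming each function is saved in the module named in its import statement), running cal_ent on the sample dataset should give about 0.971 bits, since it contains two 'yes' and three 'no' labels:

from create_dataset import create_dataset
from calculate_entropy import cal_ent

data_set, labels = create_dataset()
print(cal_ent(data_set))  # expect about 0.9710, i.e. -0.4*log2(0.4) - 0.6*log2(0.6)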
from calculate_entropy import cal_ent
from split_dataset import split_dataset


# Pick the feature to split on, using maximum information gain as the criterion
def select_feature(dataset):
    feature_entropy = {}
    num_features = len(dataset[0]) - 1
    num_samples = len(dataset)
    entropy = cal_ent(dataset)
    for feature_position in range(num_features):
        feature_gain = 0
        feature_list = [item[feature_position] for item in dataset]
        for number in set(feature_list):
            feature_set = split_dataset(dataset, feature_position, number)
            # weighted entropy of the subsets produced by this feature
            feature_gain += len(feature_set) / num_samples * cal_ent(feature_set)
        # information gain = entropy before the split minus weighted entropy after it
        feature_entropy[feature_position] = entropy - feature_gain
    # return the index of the feature with the largest information gain
    return max(feature_entropy, key=feature_entropy.get)
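On the sample dataset, splitting on 'no surfacing' (index 0) gives a gain of roughly 0.42 bits versus roughly 0.17 for 'flippers', so select_feature should return 0 (assuming the module layout implied by the imports above):

from create_dataset import create_dataset
from select_feature import select_feature

data_set, labels = create_dataset()
print(select_feature(data_set))  # expect 0, i.e. the 'no surfacing' feature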
# Return the rows whose value at feature_position equals feature_value,
# with that column removed
def split_dataset(dataset, feature_position, feature_value):
    split_result = []
    for vect in dataset:
        if vect[feature_position] == feature_value:
            reducedfeatvec = vect[:feature_position]
            reducedfeatvec.extend(vect[feature_position + 1:])
            split_result.append(reducedfeatvec)
    return split_result
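For example, splitting the sample dataset on feature 0 with value 1 keeps the three rows whose first column is 1 and strips that column:

from create_dataset import create_dataset
from split_dataset import split_dataset

data_set, labels = create_dataset()
print(split_dataset(data_set, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]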
from majority import majority
from select_feature import select_feature
from split_dataset import split_dataset


# Build the decision tree recursively
def createtree(dataset, labels):
    classlist = [example[-1] for example in dataset]
    # stop when every sample in this branch has the same class
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]
    # stop when no features are left; fall back to the majority class
    if len(dataset[0]) == 1:
        return majority(classlist)
    bestfeature = select_feature(dataset)
    bestfeaturelabel = labels[bestfeature]
    mytree = {bestfeaturelabel: {}}
    del(labels[bestfeature])  # the chosen feature is removed from the label list
    uniquevals = {example[bestfeature] for example in dataset}
    for value in uniquevals:
        sublabels = labels[:]
        mytree[bestfeaturelabel][value] = createtree(split_dataset(dataset, bestfeature, value), sublabels)
    return mytree
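createtree imports a majority function whose module is not listed above. A minimal sketch of the usual majority-vote helper (saved as majority.py so the import works; the exact implementation is an assumption) could look like this:

# Majority vote: return the class label that occurs most often in classlist
def majority(classlist):
    class_count = {}
    for vote in classlist:
        class_count[vote] = class_count.get(vote, 0) + 1
    return max(class_count, key=class_count.get)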
import matplotlib.pyplot as pyplot

# Visualize the tree that was built
decision_node = dict(boxstyle="sawtooth", fc="0.8")
leafnode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    numLeafs = 0  # initialize the leaf count
    # In Python 3, dict.keys() returns a view, so convert it to a list before indexing
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]  # the first key is the label of the feature used for this split
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict means an internal (decision) node
            numLeafs += getNumLeafs(secondDict[key])  # recurse to count the leaves of the subtree
        else:
            numLeafs += 1
    return numLeafs  # total number of leaf nodes in the tree


def cal_treedepth(tree):
    maxdepth = 0
    first_key = list(tree.keys())
    key_name = first_key[0]
    seconddict = tree[key_name]
    for key in seconddict.keys():
        if type(seconddict[key]).__name__ == 'dict':
            this_depth = 1 + cal_treedepth(seconddict[key])
        else:
            this_depth = 1
        if this_depth > maxdepth:
            maxdepth = this_depth
    return maxdepth


def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=0)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # width of the tree (number of leaves)
    depth = cal_treedepth(myTree)   # depth of the tree
    firstSides = list(myTree.keys())  # Python 3: convert the keys view to a list
    firstStr = firstSides[0]          # label text for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)  # split the x axis by leaf count
    plotMidText(cntrPt, parentPt, nodeTxt)  # annotate the edge with the feature value
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # draw top-down, so decrease y
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict is a subtree, not a leaf
            plotTree(secondDict[key], cntrPt, str(key))  # recurse downwards
        else:  # leaf node
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW  # x coordinate of the leaf
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafnode)  # draw the leaf
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  # edge label
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # restore y for the next call


def createPlot(inTree):
    fig = pyplot.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = pyplot.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(cal_treedepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    pyplot.show()
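The plotting code can be tried on its own with one of the hard-coded trees in retrieveTree, without building a tree first (assuming the file is saved as tree_plot.py, as the import in the main script suggests):

from tree_plot import createPlot, retrieveTree

createPlot(retrieveTree(0))  # draws the two-level 'no surfacing' / 'flippers' tree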
from create_dataset import create_dataset
from createtree import createtree
from tree_plot import *


# Classify a test vector by walking the tree from the root down to a leaf
def classify(tree, labels, vect):
    keylist = list(tree.keys())
    keyname = keylist[0]
    position_index = labels.index(keyname)  # which feature this node tests
    secondtree = tree[keyname]
    class_label = None
    for key in secondtree.keys():
        if key == vect[position_index]:
            if type(secondtree[key]).__name__ == 'dict':
                class_label = classify(secondtree[key], labels, vect)  # descend into the subtree
            else:
                class_label = secondtree[key]  # reached a leaf: this is the predicted class
    return class_label


# Main program: build the tree, plot it, and classify a test vector
data_set, labels = create_dataset()
vect = [1, 1]
label_list = ['no surfacing', 'flippers']  # createtree deletes entries from labels, so keep a fresh copy for classify
my_tree = createtree(data_set, labels)
createPlot(my_tree)
print(my_tree)
print(classify(my_tree, label_list, vect))
The vect and label_list at the end form a test vector and its feature labels, used to test the classification tree.
The program output is:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
yes