决策树建模
构建决策树、显示决策树、决策树剪枝、处理缺失值(训练数据和预测数据中的)、决策树预测
案例:利用决策树进行分类问题(采用CART决策树)
假设已经有一份清洗好的数据,前四列表示属性及其值,最后一列为分类。数据示例:
slashdot,USA,yes,18,None
google,France,yes,23,Prem
digg,USA,yes,24,Basic
baidu,France,yes,23,Basic
利用这份数据做训练,对['google','France','no',22]做分类预测,并告知属于此分类的概率是多少?
另外,从这份样本数据中分析出哪些属性及其值对于分类有重要影响,即分类的重要判断因素是哪些?
思路逻辑图
Python代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#@Time: 2019-11-14 08:23
#@Author: gaoll
import time
import random
import math
import numpy as np
import matplotlib.pyplot as plt
#定义树节点类
class decisionnode():
def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
self.col = col
self.value = value
self.results = results
self.tb = tb
self.fb = fb
#定义拆分集合的方法
def divideset(rowset,column,value):
set1 = []
set2 = []
if isinstance(value,int) or isinstance(value,float):
split_func = lambda row:row[column] >= value
else:
split_func = lambda row:row[column] == value
for row in rowset:
if split_func(row):
set1.append(row)
else:
set2.append(row)
return (set1,set2)
#计算各条分支上的分类结果在总训练样本中的概率,即 当前分支上包含的分类的量/总训练样本量。
#涉及两个概率,一个是相对于整体训练数据而言,决策树的每个分类结果的概率。一个是相对于此节点而言,分到True Branch和False Branche的概率。区别在于分母。
#对分类结果计数
def get_category_counts(rowset):
counts = {}
for row in rowset:
cat = row[-1] #分类在最后一列
counts.setdefault(cat,0)
counts[cat] +=1
return counts
#相对于此次拆分而言,走不同分支的概率
def branch_cat_prob(rowset):
probs = {}
total = len(rowset)
counts = get_category_counts(rowset)
for cat in counts.keys():
probs.setdefault(cat,0)
probs[cat] = counts[cat]/total
return probs
#相对于整体而言,最终每个分类结果的概率
def results_cat_prob(rowset,total):
counts = {}
results = {}
counts = get_category_counts(rowset)
for cat in counts.keys():
results.setdefault(cat,0)
results[cat] = counts[cat] / total
return results
#定义基尼不纯度。集合中每一项分类结果随机分配到其他分类结果上出现的概率总和。即相对于一次拆分的各分类概率乘积之和。
def giniimpurity(rowset):
imp = 0.0
cat_probs = branch_cat_prob(rowset)
for cat1 in cat_probs.keys():
cat1_prob = cat_probs[cat1]
for cat2 in cat_probs.keys():
cat2_prob = cat_probs[cat2]
if cat2 != cat1:
imp += cat1_prob * cat2_prob
else:
continue
return imp
#定义熵。对集合上每个分类结果cat出现的概率p(cat),求 -p(cat)*log(2,p(cat)) 之和。结合y=-xlog(2,x)的图形更加容易理解熵的大小表示的集合混乱程度。
def entropy(rowset):
ent = 0.0
cat_probs = branch_cat_prob(rowset)
for cat in cat_probs.keys():
cat_prob = cat_probs[cat]
ent = ent - cat_prob * (math.log(cat_prob) / math.log(2))
return ent
#调用基尼不纯度和熵计算下原训练数据的混乱程度
#giniimpurity(train_data) #0.64
#entropy(train_data) #1.52
#开始构建决策树,。注意:子分支还是树。
def buildtree(rowset,T,scoref=entropy,mingain=0.01):
if len(rowset) == 0:
return decisionnode()
current_score = scoref(rowset) #当前集合的熵或系数
total = len(rowset)
#选择列、选择该列的数据值对集合进行拆分,拆分的依据选择imp最小或者熵最小,即信息增益最大的那个拆分方式。
#总共几列? 每列有多少个值可以选。
col_value_set = {}
for row in rowset:
for i in range(len(row) -1):
col_value_set.setdefault(i,[])
col_value_set[i].append(row[i]) #存放每列可供选择的值。
#记录最佳拆分方案
best_infogain = 0.0
best_col,best_value = -1,None #存放最佳拆分列和值。
results = None #无需拆分时存放 分类和概率数据
best_tb,best_fb = None,None #存放True和False分支集合
for col in col_value_set.keys():
for value in col_value_set[col]:
(set1,set2)=divideset(rowset,col,value)
weight = len(set1)/total
infogain = current_score - (weight*scoref(set1) + (1-weight)*scoref(set2)) #计算信息增益
if infogain >best_infogain:
best_infogain = infogain
(best_col,best_value )= (col,value)
(best_tb,best_fb) = (set1,set2)
if best_infogain > mingain:
#最佳拆分方案的信息增益大于mingain,对集合进行拆分,创建子分支
trueBranch = buildtree(best_tb,T)
falseBranch = buildtree(best_fb,T)
return decisionnode(col=best_col,value=best_value,tb=trueBranch,fb=falseBranch) #这里注意tb和fb的赋值是子树,不要赋值成列表best_tb,best_fb,否则后面无法对树结构进行迭代
else:
#最佳拆分方案的信息增益小于mingain,停止向下拆分。返回一个叶节点
cat_results = results_cat_prob(rowset,T)
return decisionnode(results=cat_results)
#对树进行浏览、展示
#显示方法1,树状图呈现判断条件和分支。选择点坐标打印节点判断条件,子节点和父节点之间连线。
#获取树的总宽度,确定需要多大的画布宽度,即最大横坐标
def getwidth(tree):
if tree.results !=None:
return 1
else:
return getwidth(tree.tb) + getwidth(tree.fb)
#获取树的总深度,确定需要多大的画布高度,即最大纵坐标
def getdepth(tree):
if tree.results!=None:
return 0
else:
return max(getdepth(tree.tb), getdepth(tree.fb)) + 1
#画树状图。(1)该节点坐标(x,y)。在该节点上加上文本说明(col,val),也可以给文本留一个长度和高度,如(0.2,0.1),那么标记文本的坐标为(x-0.2,y-0.1)
# (2)truebranch节点坐标(x-w1/2,y-2),falsebranch节点坐标(x+w2/2,y-2). 为子节点和父节点之间的连线。
def drawnode(fig,tree,x,y):
if tree.results == None:
#如果不是叶节点,打印判断条件和连线子节点
#打印判断信息
txt =':'.join([str(tree.col),str(tree.value)])
fig.text(x+0.1,y-0.1,txt)
#计算子分支的宽度
w1=getwidth(tree.tb)*2
w2=getwidth(tree.fb)*2
#true和false两个子节点的坐标位置分别为:
(x1,y1) = (x-w1/2,y-2) #true的子节点位于它的子分支多占总宽度的中心。距离父节点的横坐标位置为 x-w1/2,纵坐标固定好的比父节点低2个高度.
(x2,y2) = (x+w2/2,y-2)
#连接父节点和子节点
fig.plot([x,x1],[y,y1])
fig.plot([x,x2],[y,y2])
#绘制分支的节点
drawnode(fig,tree.tb,x1,y1)
drawnode(fig,tree.fb,x2,y2)
else:
#如果是叶节点,直接打印文本,不需要考虑子节点
txt = '\n'.join(['%s:%0.5f'%(cat,tree.results[cat]) for cat in tree.results.keys()])
fig.text(x+0.1,y-0.1,txt) #在打印分类及概率信息
def drawtree(tree,tree_jpg='decision_tree.jpg'):
width = getwidth(tree) *2 + 5
height = getdepth(tree)*2 + 2
fig = plt.figure(figsize=(width, height))
ax = fig.add_subplot(1,1,1)
ax.set_title('decision tree')
drawnode(ax,tree,width/2,height-1)
fig.savefig(tree_jpg)
plt.show()
#显示方法2,文本形式呈现判断条件和分支.不如树状图更直观
def printtree(tree,indent=''):
if tree.results!=None:
#打印叶节点
print(str(tree.results))
else:
#打印判断条件
print(str(tree.col) +':' +str(tree.value)+'?')
#打印true分支
print(indent+'T->' , printtree(tree.tb,indent+' '))
#打印false分支
print(indent+'F->' , printtree(tree.fb,indent+' '))
#printtree(tree)
#决策树剪枝,解决过拟合的问题。
#对具有相同父节点的节点进行检查,如果熵的增加量小于某个阈值,就对叶节点进行合并,合并成一个新的叶节点。
#剪枝的过程和前面提到的提前停止拆分的过程可以结合使用。情况:某次拆分对熵的降低不大,但下一次的拆分使熵大幅度的降低,所以一般停止拆分的阈值<剪枝的阈值。
def prune(tree,total,mingain=0.1,scoref=entropy):
#非叶节点进行剪枝操作
if tree.tb.results==None:
prune(tree.tb,total,mingain,scoref)
if tree.fb.results==None:
prune(tree.fb,total,mingain,scoref)
#如果一个节点的两个子节点都是叶节点,则判断叶节点是否需要合并.即检查合并后的集合的熵和这两个集合的熵的增益值。
#构造true和false两个分支的节点的集合,
if tree.tb.results!=None and tree.fb.results!=None:
set_tb = []
set_fb = []
for cat in tree.tb.results.keys():
count = math.ceil(total * tree.tb.results[cat])
set_tb += [cat]*count
for cat in tree.fb.results.keys():
count = math.ceil(total * tree.fb.results[cat])
set_fb += [cat] *count
#两个叶节点的集合合并
set_merge = set_tb + set_fb
#计算信息增益
delta = scoref(set_merge) -(scoref(set_tb)+scoref(set_fb))/2
if delta < mingain:
new_results = {}
for cat,value in tree.tb.results.items():
new_results.setdefault(cat,0)
new_results[cat]+=value
for cat,value in tree.fb.results.items():
new_results.setdefault(cat,0)
new_results[cat]+=value
#print(new_results)
tree.tb = None
tree.fb = None
tree.results=new_results
#利用决策树做分类预测
def classify(observation,tree):
if tree.results!=None:
return tree.results
else:
col_value = observation[tree.col] #观察列
value = tree.value
branch = None
if isinstance(value,int) or isinstance(value,float):
if col_value >=value:
branch=tree.tb
else:
branch=tree.fb
else:
if col_value ==value:
branch=tree.tb
else:
branch=tree.fb
return classify(observation,branch)
def classify_observations(observation_file,tree):
for line in open(observation_file,'r'):
observation = line.strip().split(',')
print(classify(observation,tree))
if __name__ == '__main__':
#加载数据
data_file ='decision_tree_example.txt'
train_data = [line.strip().split(',')for line in open(data_file,'r')]
#训练树
tree =buildtree(train_data,len(train_data)) #运行,tree中保存着一个经过训练的决策树
#绘制决策树图
drawtree(tree)
#决策树修剪枝
prune(tree=tree,total=len(train_data),mingain=1.0,scoref=entropy)
#修剪枝后的决策树绘图,与未修剪前的进行比较
drawtree(tree)
#利用决策树进行分类预测
classify(['google','France','yes',23,'Prem'],tree)
#Out[196]: {'Prem': 0.2}
classify_observations(data_file,tree)
结果展示
#1、决策树
#2、剪枝后的决策树
#3、预测结果
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'Prem': 0.2}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'Prem': 0.2}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'Prem': 0.2}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'None': 0.06666666666666667}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}
{'Basic': 0.06666666666666667}
{'None': 0.33333333333333337, 'Basic': 0.3333333333333333}