通常,信息增益越大,意味着使用属性a进行划分所获得的“纯度提升”越大。
ID3决策树和CART决策树
from sklearn.datasets import load_iris
import numpy as np
import math
from collections import Counter
class decisionnode:
    '''
    A node of the binary decision tree.

    Internal nodes hold a split (feature index ``d`` and threshold ``thre``)
    plus two children; leaf nodes hold the predicted class in ``results``.
    Every node also keeps ``NH`` and ``max_label`` so the tree can be pruned
    after it has been grown.
    '''

    def __init__(self, d=None, thre=None, results=None, NH=None, lb=None, rb=None, max_label=None):
        # Split definition: feature (dimension) index and threshold value.
        self.d, self.thre = d, thre
        # Children: lb receives samples with x[d] <= thre, rb those with x[d] > thre.
        self.lb, self.rb = lb, rb
        # Class predicted by a leaf (None for internal nodes).
        self.results = results
        # Pruning bookkeeping: sample count times empirical entropy of the
        # node, and the most frequent label among the node's samples.
        self.NH, self.max_label = NH, max_label
def entropy(y):
    '''
    Compute the Shannon entropy (in bits) of a collection of labels.

    :param y: labels; a NumPy array of any shape (including 0-d) or any
              sequence accepted by ``np.asarray``.
    :return: entropy in bits; 0.0 for an empty or single-class input.
    '''
    labels = np.asarray(y).ravel()
    n = labels.size
    if n == 0:
        # The previous version called y.item() here and raised ValueError;
        # an empty node carries no uncertainty.
        return 0.0
    ent = 0.0
    # One counting pass instead of rescanning y once per class.
    for count in Counter(labels.tolist()).values():
        p = count / n
        ent -= p * math.log2(p)
    return ent
def Gini(y):
    '''
    Compute the Gini impurity  1 - sum_k p_k^2  of a collection of labels.

    :param y: labels (1-D NumPy array or any sized iterable).
    :return: Gini impurity; 0.0 for an empty or single-class input.
    '''
    n = len(y)
    if n == 0:
        # An empty set has no impurity (the previous version returned 1).
        return 0.0
    gini = 1.0
    # One counting pass instead of rescanning y once per class.
    for count in Counter(y).values():
        p = count / n
        gini -= p * p
    return gini
def GainEnt_max(X, y, d):
    '''
    Find the binary split on feature ``d`` with the largest information gain.

    Candidate thresholds are the midpoints between consecutive distinct
    values of ``X[:, d]``; a sample goes to the "small" side when its value
    is <= the threshold and to the "big" side when it is > the threshold.

    :param X: 2-D sample matrix, one row per sample.
    :param y: 1-D label array aligned with the rows of X.
    :param d: feature (column) index, int.
    :return: (best_gain, best_threshold); (0, 0) when no candidate
             threshold gives a positive gain.
    '''
    ent_X = entropy(y)
    column = X[:, d]
    n = len(y)
    values = sorted(set(column))
    Gain = 0
    thre = 0
    for lo, hi in zip(values, values[1:]):
        thre_temp = (lo + hi) / 2
        # Boolean masks keep the original sample order, exactly like the
        # index lists they replace, without a Python-level scan per threshold.
        y_small = y[column <= thre_temp]
        y_big = y[column > thre_temp]
        Gain_temp = ent_X - (len(y_small) / n) * entropy(y_small) \
            - (len(y_big) / n) * entropy(y_big)
        # For C4.5's gain ratio, divide Gain_temp by the split's intrinsic
        # value  -sum(|D_v|/|D| * log2(|D_v|/|D|))  before comparing.
        # Strict '<': on ties the smallest threshold is kept.
        if Gain < Gain_temp:
            Gain = Gain_temp
            thre = thre_temp
    return Gain, thre
def Gini_index_min(X, y, d):
    '''
    Find the binary split on feature ``d`` with the smallest weighted Gini index.

    Candidate thresholds are the midpoints between consecutive distinct
    values of ``X[:, d]``; a sample goes to the "small" side when its value
    is <= the threshold and to the "big" side when it is > the threshold.

    :param X: 2-D sample matrix (a single 1-D sample is also accepted).
    :param y: 1-D label array aligned with the rows of X.
    :param d: feature (column) index, int.
    :return: (best_gini_index, best_threshold); (1, 0) when ``X[:, d]`` has
             a single distinct value, i.e. there is no candidate split.
    '''
    # Turns a lone 1-D sample into a 1-row matrix; no-op for 2-D input.
    X = X.reshape(-1, len(X.T))
    column = X[:, d]
    n = len(y)
    values = sorted(set(column))
    Gini_index = 1
    thre = 0
    for lo, hi in zip(values, values[1:]):
        thre_temp = (lo + hi) / 2
        # Order-preserving boolean masks instead of per-threshold index scans.
        y_small = y[column <= thre_temp]
        y_big = y[column > thre_temp]
        Gini_index_temp = (len(y_small) / n) * Gini(y_small) \
            + (len(y_big) / n) * Gini(y_big)
        # Strict '>': on ties the smallest threshold is kept.
        if Gini_index > Gini_index_temp:
            Gini_index = Gini_index_temp
            thre = thre_temp
    return Gini_index, thre
def attribute_based_on_GainEnt(X, y):
    '''
    Choose the feature whose best binary split has the largest information gain.

    :param X: 2-D sample matrix, one row per sample.
    :param y: labels aligned with the rows of X.
    :return: (gain, threshold, feature_index); (0, 0, 0) when no feature
             yields a positive gain.
    '''
    best = (0, 0, 0)  # (gain, threshold, dimension index)
    for dim in np.arange(len(X[0])):
        gain, threshold = GainEnt_max(X, y, dim)
        # Strict comparison: on ties the earliest dimension wins.
        if gain > best[0]:
            best = (gain, threshold, dim)
    return best
def attribute_based_on_Giniindex(X, y):
    '''
    Choose the feature whose best binary split has the smallest weighted Gini index.

    :param X: 2-D sample matrix, one row per sample.
    :param y: labels aligned with the rows of X.
    :return: (gini_index, threshold, feature_index); (1, 0, 0) when no
             feature offers a split with Gini index below 1.
    '''
    best = (1, 0, 0)  # (Gini index, threshold, dimension index)
    for dim in np.arange(len(X.T)):
        gini_index, threshold = Gini_index_min(X, y, dim)
        # Strict comparison: on ties the earliest dimension wins.
        if gini_index < best[0]:
            best = (gini_index, threshold, dim)
    return best
def devide_group(X, y, thre, d):
    '''
    Split the samples on feature ``d`` at threshold ``thre``.

    :param X: 2-D sample matrix, one row per sample.
    :param y: label array aligned with the rows of X.
    :param thre: threshold value.
    :param d: feature (column) index.
    :return: (X_small, y_small, X_big, y_big). The "small" part holds the
             rows with X[:, d] <= thre and the "big" part the rows with
             X[:, d] > thre; the original row order is preserved.
    '''
    column = X[:, d]
    small_mask = column <= thre
    big_mask = column > thre
    return X[small_mask], y[small_mask], X[big_mask], y[big_mask]
def NtHt(y):
    '''
    Return N_t * H_t: the number of labels times their empirical entropy.

    This is the per-node term of the pruning loss. The three values are
    also printed as a trace.
    '''
    node_entropy = entropy(y)
    size = len(y)
    product = node_entropy * size
    print('ent={},y_len={},all={}'.format(node_entropy, size, product))
    return product
def maxlabel(y):
    '''
    Return the most frequent label in ``y``.

    Among equally frequent labels, the one encountered first wins.
    Raises IndexError for an empty ``y``.
    '''
    ranking = Counter(y).most_common(1)
    best_label, _count = ranking[0]
    return best_label
def buildtree(X, y, method='Gini'):
    '''
    Recursively grow a binary decision tree on (X, y).

    :param X: 2-D sample matrix, one row per sample.
    :param y: 1-D NumPy label array aligned with the rows of X.
    :param method: 'Gini' (CART: minimise the weighted Gini index) or
                   'GainEnt' (ID3-style: maximise the information gain).
    :return: the root ``decisionnode``. Every node stores NH (N_t * H_t) and
             max_label so the tree can be pruned later.
    :raises ValueError: for an unknown ``method`` or an empty ``y``.
    '''
    if method == 'Gini':
        select_attribute = attribute_based_on_Giniindex
    elif method == 'GainEnt':
        select_attribute = attribute_based_on_GainEnt
    else:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError("method must be 'Gini' or 'GainEnt', got {!r}".format(method))
    if y.size == 0:
        raise ValueError('cannot build a tree from an empty label set')
    if y.size == 1:
        nh = NtHt(y)
        max_label = maxlabel(y)
        return decisionnode(results=y.item(), NH=nh, max_label=max_label)

    score, thre, d = select_attribute(X, y)
    if method == 'GainEnt':
        # ID3: split only while some threshold still gives a positive gain.
        worth_splitting = score > 0
    else:
        # CART: the Gini index is never negative, so keep splitting while
        # the node still contains more than one class.
        worth_splitting = score >= 0 and len(set(y)) > 1

    if worth_splitting:
        X_small, y_small, X_big, y_big = devide_group(X, y, thre, d)
        # When no candidate threshold exists (e.g. identical feature vectors
        # with different labels), the selectors fall back to thre=0, d=0 and
        # one side may come back empty. Recursing would then hit an empty
        # node and re-process the unchanged node, so make a leaf instead.
        if y_small.size > 0 and y_big.size > 0:
            left_branch = buildtree(X_small, y_small, method=method)
            right_branch = buildtree(X_big, y_big, method=method)
            nh = NtHt(y)
            max_label = maxlabel(y)
            return decisionnode(d=d, thre=thre, NH=nh, lb=left_branch, rb=right_branch, max_label=max_label)

    nh = NtHt(y)
    max_label = maxlabel(y)
    # Predict the majority label: a node that cannot be split usefully may
    # still be mixed, in which case y[0] would be an arbitrary pick. For a
    # pure node the two are identical.
    return decisionnode(results=max_label, NH=nh, max_label=max_label)
def printtree(tree, indent='-', dict_tree=None, direct='L'):
    '''
    Print the tree to stdout and return it as nested dicts for treePlotter.

    An internal node becomes {"d:thre?": {'L': ..., 'R': ...}}, where each
    branch value is either a leaf label (as str) or another such dict.

    :param tree: node with ``results``, ``d``, ``thre``, ``lb``, ``rb``.
    :param indent: dash prefix for the current depth; '-' marks the root call.
    :param dict_tree: unused; kept only for backward compatibility (it used
                      to default to a shared mutable ``{}``).
    :param direct: 'L' or 'R', the parent branch this node hangs on.
    :return: the nested-dict representation described above.
    '''
    # Leaf node
    if tree.results is not None:
        print(tree.results)
        return {direct: str(tree.results)}
    # Split condition, e.g. "2:2.6?"
    stri = str(tree.d) + ":" + str(tree.thre) + "?"
    print(stri + " ")
    print(indent + "L->")
    branches = printtree(tree.lb, indent=indent + "-", direct='L')
    print(indent + "R->")
    # Each recursive call returns a fresh dict, so merging in place is safe.
    branches.update(printtree(tree.rb, indent=indent + "-", direct='R'))
    if indent != '-':
        return {direct: {stri: branches}}
    return {stri: branches}
def classify(observation, tree):
    '''
    Predict the label of one sample by walking from ``tree`` down to a leaf.

    At each internal node the walk goes right when observation[d] > thre
    and left otherwise, so values equal to the threshold go left.
    '''
    node = tree
    while node.results is None:
        node = node.rb if observation[node.d] > node.thre else node.lb
    return node.results
def pruning(tree, alpha=0.1):
    '''
    Post-prune ``tree`` in place, bottom-up.

    An internal node whose two children are (or become) leaves is collapsed
    into a leaf labelled with its ``max_label`` when that does not increase
    the local loss  sum(N_t * H_t) + alpha * (number of leaves):
        before = lb.NH + rb.NH + 2 * alpha,   after = NH + alpha.

    :param tree: root node; a leaf is left untouched.
    :param alpha: complexity penalty added per leaf.
    '''
    # A leaf has no children to inspect (previously: AttributeError on
    # tree.lb.results, e.g. for a tree grown from single-class data).
    if tree.results is not None:
        return
    if tree.lb.results is None:
        pruning(tree.lb, alpha)
    if tree.rb.results is None:
        pruning(tree.rb, alpha)
    if tree.lb.results is not None and tree.rb.results is not None:
        before_pruning = tree.lb.NH + tree.rb.NH + 2 * alpha
        after_pruning = tree.NH + alpha
        print('before_pruning={},after_pruning={}'.format(
            before_pruning, after_pruning))
        if after_pruning <= before_pruning:
            print('pruning--{}:{}?'.format(tree.d, tree.thre))
            tree.lb, tree.rb = None, None
            tree.results = tree.max_label
if __name__ == '__main__':
    iris = load_iris()
    X = iris.data
    y = iris.target
    # Randomly permute the sample indices 0..X.shape[0]-1 (150 for iris).
    permutation = np.random.permutation(X.shape[0])
    shuffled_dataset = X[permutation, :]
    shuffled_labels = y[permutation]
    # First 100 shuffled samples for training, the remainder for testing.
    n_train = 100
    train_data = shuffled_dataset[:n_train, :]
    train_label = shuffled_labels[:n_train]
    test_data = shuffled_dataset[n_train:, :]
    test_label = shuffled_labels[n_train:]

    def _count_correct(tree):
        '''Return how many test samples ``tree`` classifies correctly.'''
        return sum(1 for sample, label in zip(test_data, test_label)
                   if classify(sample, tree) == label)

    def _report_accuracy(cart_tree, id3_tree):
        '''Print the number of correct test predictions of both trees.'''
        print("CARTTree:{}".format(_count_correct(cart_tree)))
        print("C3Tree:{}".format(_count_correct(id3_tree)))

    tree1 = buildtree(train_data, train_label, method='Gini')
    print('=============================')
    tree2 = buildtree(train_data, train_label, method='GainEnt')
    a = printtree(tree=tree1)
    b = printtree(tree=tree2)
    _report_accuracy(tree1, tree2)

    # Explicit import instead of `from pylab import *`, which dumped pylab's
    # (NumPy's) names into the module globals and can shadow builtins.
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import treePlotter
    mpl.rcParams['font.sans-serif'] = ['SimHei']  # default font (presumably so CJK text renders)
    mpl.rcParams['axes.unicode_minus'] = False  # keep '-' from rendering as a box in saved figures
    treePlotter.createPlot(a, 1)
    treePlotter.createPlot(b, 2)

    # Post-pruning
    pruning(tree=tree1, alpha=4)
    pruning(tree=tree2, alpha=4)
    a = printtree(tree=tree1)
    b = printtree(tree=tree2)
    _report_accuracy(tree1, tree2)
    treePlotter.createPlot(a, 3)
    treePlotter.createPlot(b, 4)
    plt.show()
附上treePlotter.py
可视化
import matplotlib.pyplot as plt
# Text-box and arrow formats used by the tree plot
decisionNode = {"boxstyle": "round4", "color": "#3366FF"}  # box style for internal (decision) nodes
leafNode = {"boxstyle": "circle", "color": "#FF6633"}  # box style for leaf nodes
arrow_args = {"arrowstyle": "<-", "color": "g"}  # arrow style for parent/child edges
# Draw an annotation with an arrow
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    '''
    Draw ``nodeTxt`` in a box styled by ``nodeType`` at ``centerPt``,
    connected by an ``arrow_args`` arrow to ``parentPt``.

    Both points are in axes-fraction coordinates; drawing happens on
    ``createPlot.ax1``.
    '''
    placement = dict(xy=parentPt, xycoords='axes fraction',
                     xytext=centerPt, textcoords='axes fraction')
    createPlot.ax1.annotate(nodeTxt, va="center", ha="center",
                            bbox=nodeType, arrowprops=arrow_args, **placement)
# Count leaf nodes
def getNumLeafs(myTree):
    '''
    Count the leaves of a tree given in nested-dict form.

    ``myTree`` is ``{node_label: {branch: child, ...}}``; a child that is a
    dict (any dict, including subclasses) is a subtree, anything else is a leaf.

    :return: number of leaves (0 if the root has no branches).
    '''
    firstStr = list(myTree.keys())[0]
    branches = myTree[firstStr]
    return sum(getNumLeafs(child) if isinstance(child, dict) else 1
               for child in branches.values())
# Compute the number of tree levels
def getTreeDepth(myTree):
    '''
    Return the depth of a tree given in nested-dict form.

    ``myTree`` is ``{node_label: {branch: child, ...}}``; a child that is a
    dict (any dict, including subclasses) is a subtree, anything else is a
    leaf. A root whose branches are all leaves has depth 1.

    :return: depth (0 if the root has no branches).
    '''
    firstStr = list(myTree.keys())[0]
    branches = myTree[firstStr]
    return max((1 + getTreeDepth(child) if isinstance(child, dict) else 1
                for child in branches.values()), default=0)
# Write text halfway between a parent node and a child node
def plotMidText(cntrPt, parentPt, txtString):
    '''
    Write ``txtString`` (the branch label) at the midpoint between the child
    point ``cntrPt`` and the parent point ``parentPt``, on ``createPlot.ax1``.
    '''
    child_x, child_y = cntrPt[0], cntrPt[1]
    parent_x, parent_y = parentPt[0], parentPt[1]
    mid_x = (parent_x - child_x) / 2.0 + child_x
    mid_y = (parent_y - child_y) / 2.0 + child_y
    createPlot.ax1.text(mid_x, mid_y, txtString,
                        va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
    '''
    Recursively draw the nested-dict tree ``myTree`` below ``parentPt``.

    Layout state lives in function attributes initialised by createPlot:
    ``plotTree.totalW`` / ``plotTree.totalD`` (leaf count and depth of the
    whole tree) and ``plotTree.xOff`` / ``plotTree.yOff`` (the current
    drawing cursor, in axes fractions).

    :param myTree: subtree ``{node_label: {branch: child, ...}}``.
    :param parentPt: (x, y) of the parent node.
    :param nodeTxt: branch label written between parent and this node.
    '''
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this node horizontally over the slots its leaves will occupy.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) /
              2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # branch label on the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # the node box and its arrow
    secondDict = myTree[firstStr]
    # Move one level down for the children ...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # ... and back up before returning to the parent.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree, index=1):
    '''
    Draw the nested-dict tree ``inTree`` into matplotlib figure ``index``.

    Clears the figure, creates a frameless, tick-less axes stored as
    ``createPlot.ax1`` for the drawing helpers, initialises the layout state
    read by plotTree and draws the tree. Does not call plt.show().
    '''
    fig = plt.figure(index, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    leaf_count = float(getNumLeafs(inTree))
    plotTree.totalW = leaf_count
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a slot left of 0, so the first leaf lands at 0.5 / totalW.
    plotTree.xOff = -0.5 / leaf_count
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
结果
ent=0.0,y_len=37,all=0.0
ent=0.0,y_len=1,all=0.0
ent=0.0,y_len=1,all=0.0
ent=1.0,y_len=2,all=2.0
ent=0.0,y_len=26,all=0.0
ent=0.22228483068568797,y_len=28,all=6.223975259199263
ent=0.0,y_len=1,all=0.0
ent=0.0,y_len=3,all=0.0
ent=0.8112781244591328,y_len=4,all=3.2451124978365313
ent=0.0,y_len=2,all=0.0
ent=1.0,y_len=6,all=6.0
ent=0.0,y_len=29,all=0.0
ent=0.4220005168831531,y_len=35,all=14.770018090910359
ent=0.9983636725938131,y_len=63,all=62.89691137341023
ent=1.579641206421168,y_len=100,all=157.9641206421168
ent=0.0,y_len=37,all=0.0
ent=0.0,y_len=1,all=0.0
ent=0.0,y_len=1,all=0.0
ent=1.0,y_len=2,all=2.0
ent=0.0,y_len=26,all=0.0
ent=0.22228483068568797,y_len=28,all=6.223975259199263
ent=0.0,y_len=1,all=0.0
ent=0.0,y_len=3,all=0.0
ent=0.8112781244591328,y_len=4,all=3.2451124978365313
ent=0.0,y_len=2,all=0.0
ent=1.0,y_len=6,all=6.0
ent=0.0,y_len=29,all=0.0
ent=0.4220005168831531,y_len=35,all=14.770018090910359
ent=0.9983636725938131,y_len=63,all=62.89691137341023
ent=1.579641206421168,y_len=100,all=157.9641206421168
2:2.6?
-L->
0
-R->
2:4.85?
--L->
0:5.05?
---L->
1:2.45?
----L->
1
----R->
2
---R->
1
--R->
3:1.75?
---L->
3:1.55?
----L->
2:4.95?
-----L->
1
-----R->
2
----R->
1
---R->
2
2:2.6?
-L->
0
-R->
2:4.85?
--L->
0:5.05?
---L->
1:2.45?
----L->
1
----R->
2
---R->
1
--R->
3:1.75?
---L->
3:1.55?
----L->
2:4.95?
-----L->
1
-----R->
2
----R->
1
---R->
2
CARTTree:47
C3Tree:47
before_pruning=8.0,after_pruning=6.0
pruning--1:2.45?
before_pruning=10.0,after_pruning=10.223975259199264
before_pruning=8.0,after_pruning=7.245112497836532
pruning--2:4.95?
before_pruning=11.245112497836532,after_pruning=10.0
pruning--3:1.55?
before_pruning=14.0,after_pruning=18.77001809091036
before_pruning=8.0,after_pruning=6.0
pruning--1:2.45?
before_pruning=10.0,after_pruning=10.223975259199264
before_pruning=8.0,after_pruning=7.245112497836532
pruning--2:4.95?
before_pruning=11.245112497836532,after_pruning=10.0
pruning--3:1.55?
before_pruning=14.0,after_pruning=18.77001809091036
2:2.6?
-L->
0
-R->
2:4.85?
--L->
0:5.05?
---L->
2
---R->
1
--R->
3:1.75?
---L->
1
---R->
2
2:2.6?
-L->
0
-R->
2:4.85?
--L->
0:5.05?
---L->
2
---R->
1
--R->
3:1.75?
---L->
1
---R->
2
CARTTree:45
C3Tree:45
决策树的剪枝策略
决策树的剪枝策略分为预剪枝和后剪枝
预剪枝
预剪枝就是边建立决策树边进行剪枝的操作。在决策树生成的过程中,对每个节点在划分前先进行估计,若当前节点的划分不能带来决策树泛化性能的提升,则停止划分并将当前节点标记为叶子节点。
预剪枝的常用手段包括:限制树的深度、叶子节点个数、叶子节点的样本数、信息增益量等。
后剪枝
后剪枝是在决策树建立完成后再进行的剪枝操作:先从训练集生成一棵完整的决策树,然后自底向上地对非叶子节点进行考察,若将该节点对应的子树替换为叶子节点能够带来决策树泛化性能的提升,则将该子树替换为叶子节点。
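判断是否剪枝需要一定的衡量标准。这里讲的是CART算法的后剪枝方法——代价复杂度剪枝(Cost-Complexity Pruning),即CCP算法。(注:上文代码中的 pruning 函数采用的是基于经验熵的损失函数 $C_\alpha(T)=\sum_{t=1}^{|T|} N_t H_t(T)+\alpha|T|$,对两个子节点均为叶子的内部节点比较剪枝前后的局部损失,而不是CCP。)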