目录
✨主程序
✨画图程序
✨结果展示
✨参考
主程序
import os
from sklearn.linear_model import LogisticRegression
import numpy as np
import pandas as pd
import warnings
from createPlot import createPlot
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
# Threshold a continuous attribute column into a binary 0/1 column.
def con_deal(temp_df, a):
    """Binarize the pandas Series `temp_df` in place against threshold `a`.

    Every entry strictly below `a` becomes 0, every other entry becomes 1.
    The same (mutated) Series is returned for convenience.
    """
    for pos in range(len(temp_df)):
        temp_df.iat[pos] = 0 if temp_df.iat[pos] < a else 1
    return temp_df
# Search the best binary split threshold for a continuous attribute.
def con_acc(data, Y):
    """Return ``[best_accuracy, best_threshold]`` for a continuous column.

    Candidate thresholds are the midpoints between consecutive sorted values
    of `data`. For each candidate, the column is binarized with `con_deal`
    and a one-feature logistic regression is fitted and scored on the
    training data itself.

    Parameters
    ----------
    data : pd.Series
        Continuous attribute column.
    Y : array-like
        Binary class labels aligned with `data`.
    """
    vals = np.sort(np.array(data))
    # Midpoints between consecutive sorted values (n-1 candidates).
    cands = (vals[0: len(vals) - 1] + vals[1: len(vals)]) / 2
    max_acc, ind = 0, 0
    best_model, best_X = None, None
    for i in range(0, len(cands)):
        temp_df = con_deal(data.copy(), cands[i])
        X0 = np.array(temp_df).reshape(-1, 1)
        logreg = LogisticRegression()
        logreg.fit(X0, Y)
        acc = logreg.score(X0, Y)
        if max_acc < acc:
            max_acc = acc
            ind = i
            # BUG FIX: the original printed predictions from the LAST
            # iteration's model on the BEST iteration's data; keep the
            # model that actually matches the best split.
            best_model, best_X = logreg, X0
    print(round(max_acc, 3), end=', 判断结果为:\n')
    print(best_model.predict(best_X))
    return [max_acc, cands[ind]]
# Pick the best splitting attribute by single-feature training accuracy.
def getroot(X1, Y1, m):
    """Return the attribute in `m` whose one-feature logistic regression
    reaches the highest training accuracy on labels `Y1`.

    Discrete attributes (anything but 'Density'/'sugar') are fitted on their
    integer-encoded column directly; continuous ones are scored through
    `con_acc`, which searches the best binarization threshold (and prints
    its own accuracy/prediction lines).
    """
    max_acc = 0
    # BUG FIX: `root` was unbound (NameError at return) if no attribute ever
    # scored above 0; fall back to the first candidate.
    root = m[0]
    for i in m:
        print(i + '节点, 正确率为', end=':')
        if i != 'Density' and i != 'sugar':
            # Discrete attribute: fit directly on the encoded column.
            X0 = np.array(X1[i]).reshape(-1, 1)
            logreg = LogisticRegression()
            logreg.fit(X0, Y1)
            acc = logreg.score(X0, Y1)
            print(round(acc, 3), end=', 判断结果为:\n')
            print(logreg.predict(X0))
        else:
            # Continuous attribute: con_acc prints accuracy and predictions.
            acc = con_acc(X1[i], Y1)[0]
        if max_acc < acc:
            max_acc = acc
            root = i
    return root
# Build the decision tree recursively as a nested list.
def gettree(X, Xo, Y, m):
    """Grow the tree and return it as ``[root_attr, [edge, child], ...]``
    where each child is either another such list (subtree) or a leaf string.

    X  : attribute columns with discrete values integer-encoded
    Xo : the same columns with their original string values (for edge labels)
    Y  : full dataframe (encoded), including the 'popular' label column
    m  : attribute names still available for splitting

    NOTE(review): `m` is mutated via m.remove(root) and shared across sibling
    recursive calls, so an attribute consumed in one branch also disappears
    from its siblings — classic ID3 would pass each branch its own copy.
    Confirm this is intended.
    """
    n1, n2 = [], []
    root = getroot(X, Y['popular'], m)
    print('故选择' + root + '为根节点')
    n1.append(root)
    m.remove(root)
    if root == 'Density' or root == 'sugar':
        # Continuous attribute: binarize every view of the column at the best
        # threshold found by con_acc.
        div = con_acc(X[root], Y['popular'])[1]
        # NOTE(review): the third assignment sources X[root], not Y[root];
        # in the calling script X and Y slice the same dataframe, so this
        # presumably works incidentally — verify.
        X[root], Xo[root], Y[root] = con_deal(X[root], div), con_deal(Xo[root], div), con_deal(X[root], div)
    # Attr: encoded branch values; Attro: original string values.
    # NOTE(review): zip(Attr, Attro) assumes unique() yields values in the
    # same order for both frames — holds only while row order is identical.
    Attr, Attro = X[root].unique(), Xo[root].unique()
    for j, jo in zip(Attr, Attro):
        n3 = []
        if root == 'Density' or root == 'sugar' :
            # Edge label for a binarized continuous split (j is now 0 or 1,
            # so j >= div distinguishes the two sides).
            if j >= div:
                key = '≥' + str(div)
            else:
                key = '<' + str(div)
        else:
            key = jo
        print(root + '为' + key + '时:')
        n3.append(key)
        # Partition all three frames on this branch value.
        X1 = X[X[root] == j]
        Xo1 = Xo[Xo[root] == jo]
        Y0 = Y[Y[root] == j]
        Y1 = Y0['popular']
        if Y1.unique().size > 1:
            # Mixed labels in this partition: recurse to grow a subtree.
            Xn, Xon, Yn = X1, Xo1, Y0
            n3.append(gettree(Xn, Xon, Yn, m))
        else:
            # Pure partition: emit a leaf ('popular' for 是, otherwise 'bad').
            flag = 'popular' if Y1.unique() == '是' else 'bad'
            print(flag)
            n3.append(flag)
        n2.append(n3)
    n1 += n2
    return n1
# Convert the nested-list tree to a dict and render it to a PNG file.
def dealanddraw(n0, pngname):
    """Draw the decision tree `n0` and save it to `pngname`.

    `n0` has the shape ``[attribute, [edge_label, child], ...]`` where each
    child is either another such list (subtree) or a leaf string.

    BUG FIX: the original serialized the list with str() and rebuilt a dict
    via chained str.replace calls plus eval() — fragile (breaks on labels
    containing brackets/commas) and an eval security smell. Convert directly
    instead.
    """
    inTree = _list_to_tree(n0)
    plt.figure(figsize=(10, 7))
    createPlot(inTree)
    # dpi controls output resolution; bbox_inches='tight' trims the blank
    # margins around the figure.
    plt.savefig(pngname, dpi=400, bbox_inches='tight')

def _list_to_tree(node):
    """Recursively convert ``[attr, [edge, child], ...]`` into
    ``{attr: {edge: child_dict_or_leaf, ...}}``."""
    return {node[0]: {edge: (_list_to_tree(child) if isinstance(child, list) else child)
                      for edge, child in node[1:]}}
# ---- Script entry: load the dataset, encode discrete attributes, build and
# draw the decision tree ----
# BUG FIX: the original opened a file handle that was never closed; hand the
# path to pandas and let it manage the file.
watermelon3_df = pd.read_table('watermelon3.txt')
# Snapshot the attribute columns BEFORE encoding so edge labels keep their
# original string values.
Xo = watermelon3_df[['color', 'root', 'knock', 'texture', 'navel', 'tactility', 'Density', 'sugar']]
h = 0.001  # NOTE(review): appears unused; kept for backward compatibility
# Encode every discrete attribute's values as consecutive integers from 1.
for col in list(watermelon3_df.columns):
    if col != 'Density' and col != 'sugar' and col != 'popular':
        size_mapping = {val: code for code, val in enumerate(watermelon3_df[col].unique(), start=1)}
        watermelon3_df[col] = watermelon3_df[col].map(size_mapping)
X = watermelon3_df[['color', 'root', 'knock', 'texture', 'navel', 'tactility', 'Density', 'sugar']]
Y = watermelon3_df
m = list(X.columns)
n0 = gettree(X, Xo, Y, m)
# BUG FIX: replace('py', 'png') rewrote the FIRST occurrence of "py" anywhere
# in the name (e.g. "pyplot_demo.py" -> "pngplot_demo.py"); swap only the
# extension instead.
pngname = os.path.splitext(os.path.basename(os.path.realpath(__file__)))[0] + '.png'
dealanddraw(n0, pngname)
画图程序
程序名:createPlot.py
import matplotlib.pyplot as plt
# Render minus signs correctly (needed when CJK-capable fonts are configured).
plt.rcParams['axes.unicode_minus'] = False
# Box styles for decision (internal) nodes and leaf nodes.
decisionNode = dict(boxstyle="sawtooth", color='#f05b72')
leafNode = dict(boxstyle="round4", color='#826858')
# Arrow style: "<-" draws the arrow pointing from the child back to the parent anchor.
arrow_args = dict(arrowstyle="<-", color='#121a2a')
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree has the form ``{attribute: {edge_value: subtree_or_leaf, ...}}``;
    a child that is itself a dict is a subtree, anything else is a leaf.
    """
    rootLabel = list(myTree.keys())[0]
    children = myTree[rootLabel]
    # A dict-valued child contributes its own leaf count; any other child is
    # exactly one leaf.
    return sum(
        getNumLeafs(child) if type(child).__name__ == 'dict' else 1
        for child in children.values()
    )
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    A leaf directly under the root counts as depth 1; each dict-valued child
    adds one level plus its own subtree depth.
    """
    rootLabel = list(myTree.keys())[0]
    depths = []
    for child in myTree[rootLabel].values():
        if type(child).__name__ == 'dict':
            depths.append(1 + getTreeDepth(child))
        else:
            depths.append(1)
    # An empty child map yields depth 0, matching the original accumulator.
    return max(depths) if depths else 0
# Draw a single tree node with an arrow linking it to its parent.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Annotate the shared axes with one node box and its connecting arrow.

    `annotate` places the text box at `centerPt` and anchors the arrow at
    `parentPt`; both are in 'axes fraction' coordinates, where (0, 0) is the
    lower-left and (1, 1) the upper-right corner of the axes. `nodeType`
    selects the box style (decision vs. leaf); the module-level `arrow_args`
    supplies the arrow style. Relies on `createPlot.ax1` having been set by
    `createPlot` before this is called.
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
# Label the edge between a parent and child node.
def plotMidText(cntrPt, parentPt, txtString):
    """Write `txtString` at the midpoint of the parent-to-child connector."""
    midX = cntrPt[0] + (parentPt[0] - cntrPt[0]) / 2.0
    midY = cntrPt[1] + (parentPt[1] - cntrPt[1]) / 2.0
    createPlot.ax1.text(midX, midY, txtString, va="center", ha="center", rotation=20)
# Recursively draw the subtree rooted at `myTree`.
def plotTree(myTree, parentPt, nodeTxt):
    # Leaf count below this node determines the horizontal room it occupies.
    numLeafs = getNumLeafs(myTree)
    # Subtree depth; computed but unused here (totalD is set once in createPlot).
    depth = getTreeDepth(myTree)
    # First key is the attribute tested at this node.
    firstStr = list(myTree.keys())[0]
    # Center this node above its leaves: xoff sits half a slot left of the
    # next free leaf position, so advance by half this subtree's width.
    cntrPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yoff)
    # Edge label between this node and its parent.
    plotMidText(cntrPt, parentPt, nodeTxt)
    # Decision (internal) node box.
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # Children map: edge value -> subtree dict or leaf string.
    secondDict = myTree[firstStr]
    # Descend one level: y shrinks by one row per unit of total depth.
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # Dict child: recurse to draw the subtree.
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance x by one slot (1/totalW per leaf), then draw the
            # leaf box and its edge label.
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
            plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    # Restore y before returning to the caller's level.
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD
# Top-level entry: render a whole decision tree.
def createPlot(inTree):
    """Render the decision tree `inTree` (``{attr: {edge: subtree_or_leaf}}``).

    Creates a single frameless, tick-free axes and stores it on the function
    object (`createPlot.ax1`) so plotNode/plotMidText can draw on it like a
    shared canvas, then seeds plotTree's layout state and kicks off the
    recursive drawing.
    """
    plt.clf()  # start from a clean figure
    createPlot.ax1 = plt.subplot(frameon=False, xticks=[], yticks=[])
    # Layout bookkeeping shared with plotTree via function attributes:
    plotTree.totalW = float(getNumLeafs(inTree))   # leaf count -> x spacing
    plotTree.totalD = float(getTreeDepth(inTree))  # depth -> y spacing
    # Start half a leaf-slot left of 0 so each leaf lands on slot centers
    # spaced 1/totalW apart.
    plotTree.xoff = -0.5 / plotTree.totalW
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
代码修改细节
(1)分别使用两个Python程序实现构造树,画树
(2)区分连续性数据以及离散型数据
如果使用的是连续性数据,在gettree方法中,下述语句or后面再加判断条件即可。
if root == 'Density' or root == 'sugar' :
如你想增加一个名为sugarRate的属性时,修改为下述语句即可
if root == 'Density' or root == 'sugar' or root=='sugarRate' :
若是想使用离散型数据,在txt数据集中直接修改即可,不需要修改代码。
(3) 关于txt
txt文件名为watermelon3.txt,popular是最后的判断属性,因为在主程序中判断是由‘是’/‘否’来判断的,所以我并没有修改。但发现画图程序是不可以显示中文的,所以使用时将所有中文属性改为英文。如下所示。
color root knock texture navel tactility Density sugar popular
green curl dull accurate sunken smooth 0.697 0.46 是
black curl depressing accurate sunken smooth 0.774 0.376 是
black curl dull accurate sunken smooth 0.634 0.264 是
green curl depressing accurate sunken smooth 0.608 0.318 是
white curl dull accurate sunken smooth 0.556 0.215 是
green LittleCurl dull accurate LittleSunken sticky 0.403 0.237 是
black LittleCurl dull LittleVague LittleSunken sticky 0.481 0.149 是
black LittleCurl dull accurate LittleSunken smooth 0.437 0.211 是
black LittleCurl depressing LittleVague LittleSunken smooth 0.666 0.091 否
green hard clear accurate even sticky 0.243 0.267 否
white hard clear vague even smooth 0.245 0.057 否
white curl dull vague even sticky 0.343 0.099 否
green LittleCurl dull LittleVague sunken smooth 0.639 0.161 否
white LittleCurl depressing LittleVague sunken smooth 0.657 0.198 否
black LittleCurl dull accurate LittleSunken sticky 0.36 0.37 否
white curl dull vague even smooth 0.593 0.042 否
green curl depressing LittleVague LittleSunken smooth 0.719 0.103 否
green curl dull accurate sunken smooth 0.697 0.46 是
black curl depressing accurate sunken smooth 0.774 0.376 是
black curl dull accurate sunken smooth 0.634 0.264 是
green curl depressing accurate sunken smooth 0.608 0.318 是
white curl dull accurate sunken smooth 0.556 0.215 是
green LittleCurl dull accurate LittleSunken sticky 0.403 0.237 是
black LittleCurl dull LittleVague LittleSunken sticky 0.481 0.149 是
black LittleCurl dull accurate LittleSunken smooth 0.437 0.211 是
black LittleCurl depressing LittleVague LittleSunken smooth 0.666 0.091 否
green hard clear accurate even sticky 0.243 0.267 否
white hard clear vague even smooth 0.245 0.057 否
white curl dull vague even sticky 0.343 0.099 否
green LittleCurl dull LittleVague sunken smooth 0.639 0.161 否
white LittleCurl depressing LittleVague sunken smooth 0.657 0.198 否
black LittleCurl dull accurate LittleSunken sticky 0.36 0.37 否
white curl dull vague even smooth 0.593 0.042 否
green curl depressing LittleVague LittleSunken smooth 0.719 0.103 否
green curl dull accurate sunken smooth 0.697 0.46 是
black curl depressing accurate sunken smooth 0.774 0.376 是
black curl dull accurate sunken smooth 0.634 0.264 是
green curl depressing accurate sunken smooth 0.608 0.318 是
white curl dull accurate sunken smooth 0.556 0.215 是
green LittleCurl dull accurate LittleSunken sticky 0.403 0.237 是
black LittleCurl dull LittleVague LittleSunken sticky 0.481 0.149 是
black LittleCurl dull accurate LittleSunken smooth 0.437 0.211 是
black LittleCurl depressing LittleVague LittleSunken smooth 0.666 0.091 否
green hard clear accurate even sticky 0.243 0.267 否
white hard clear vague even smooth 0.245 0.057 否
white curl dull vague even sticky 0.343 0.099 否
green LittleCurl dull LittleVague sunken smooth 0.639 0.161 否
white LittleCurl depressing LittleVague sunken smooth 0.657 0.198 否
black LittleCurl dull accurate LittleSunken sticky 0.36 0.37 否
white curl dull vague even smooth 0.593 0.042 否
green curl depressing LittleVague LittleSunken smooth 0.719 0.103 否
结果展示

该博客介绍用Python实现决策树相关内容。包含主程序、画图程序,说明了代码修改细节,如分别用两个程序构造树和画树,区分连续性与离散型数据的处理方式,还提到txt文件使用及画图程序显示问题,最后给出参考文章。

被折叠的评论
为什么被折叠?



