Python生成对率回归决策树

目录

主程序

画图程序 

代码修改细节

结果展示

参考 


主程序

import os
from sklearn.linear_model import LogisticRegression
import numpy as np
import pandas as pd
import warnings
from createPlot import createPlot
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")


# 定义连续值处理函数
def con_deal(temp_df, a):
    for j in range(0, len(temp_df)):
        temp_df.iat[j] = 0 if (temp_df.iat[j] < a) else 1
    return temp_df


# 定义计算连续值正确率的函数
def con_acc(data, Y):
    a = np.sort(np.array(data))
    a = (a[0: len(a) - 1] + a[1: len(a)]) / 2
    max_acc, ind = 0, 0
    for i in range(0, len(a)):
        temp_df = con_deal(data.copy(), a[i])
        X0 = np.array(temp_df).reshape(-1, 1)
        logreg = LogisticRegression()
        logreg.fit(X0, Y)
        acc = logreg.score(X0, Y)
        if max_acc < acc:
            max_acc = acc
            ind = i
            temp_df0 = X0
    print(round(max_acc, 3), end=', 判断结果为:\n')
    print(logreg.predict(temp_df0))
    return [max_acc, a[ind]]


# 获取根节点函数
def getroot(X1, Y1, m):
    max_acc = 0
    for i in m:
        if i != 'Density' and i != 'sugar':
            print(i + '节点, 正确率为', end=':')
            X0 = np.array(X1[i]).reshape(-1, 1)
            logreg = LogisticRegression()
            logreg.fit(X0, Y1)
            acc = logreg.score(X0, Y1)
            print(round(acc, 3), end=', 判断结果为:\n')
            print(logreg.predict(X0))
            if max_acc < acc:
                max_acc = acc
                root = i
        else:
            print(i + '节点, 正确率为', end=':')
            acc = con_acc(X1[i], Y1)[0]
            if max_acc < acc:
                max_acc = acc
                root = i
    return root


# 获取决策树数组函数
def gettree(X, Xo, Y, m):
    n1, n2 = [], []
    root = getroot(X, Y['popular'], m)
    print('故选择' + root + '为根节点')
    n1.append(root)
    m.remove(root)
    if root == 'Density' or root == 'sugar':
        div = con_acc(X[root], Y['popular'])[1]
        X[root], Xo[root], Y[root] = con_deal(X[root], div), con_deal(Xo[root], div), con_deal(X[root], div)
    #    print(X, Xo)
    Attr, Attro = X[root].unique(), Xo[root].unique()
    #    print(Attr, Attro)
    for j, jo in zip(Attr, Attro):
        n3 = []
        if root == 'Density' or root == 'sugar' :
            if j >= div:
                key = '≥' + str(div)
            else:
                key = '<' + str(div)
        else:
            key = jo
        print(root + '为' + key + '时:')
        n3.append(key)
        X1 = X[X[root] == j]
        Xo1 = Xo[Xo[root] == jo]
        Y0 = Y[Y[root] == j]
        Y1 = Y0['popular']
        if Y1.unique().size > 1:
            Xn, Xon, Yn = X1, Xo1, Y0
            n3.append(gettree(Xn, Xon, Yn, m))
        else:
            flag = 'popular' if Y1.unique() == '是' else 'bad'
            print(flag)
            n3.append(flag)
        n2.append(n3)
    n1 += n2
    return n1


# 数组处理及绘制函数
def dealanddraw(n0, pngname):
    alstr = str(n0)
    alstr = alstr.replace(',', ':');
    alstr = alstr.replace(']: [', ',')
    alstr = alstr.replace(']:', '],')
    alstr = alstr.replace('[', '{');
    alstr = alstr.replace(']', '}')
    inTree = eval(alstr)
    #    print(inTree)
    plt.figure(figsize=(10, 7))
    createPlot(inTree)
    #    dpi, 控制每英寸长度上的分辨率;bbox_inches, 能删除figure周围的空白部分
    plt.savefig(pngname, dpi=400, bbox_inches='tight')


f = open('watermelon3.txt')
watermelon3_df = pd.read_table(f)
Xo = watermelon3_df[['color', 'root', 'knock', 'texture', 'navel', 'tactility', 'Density', 'sugar']]
m = list(watermelon3_df.columns)
h = 0.001
for i in m:
    if i != 'Density' and i != 'sugar' and i != 'popular':
        size_mapping = {}
        m0 = watermelon3_df[i].unique()
        j = 1
        for i0 in m0:
            size_mapping[i0] = j
            j += 1
        #        print(size_mapping)
        watermelon3_df[i] = watermelon3_df[i].map(size_mapping)

X = watermelon3_df[['color', 'root', 'knock', 'texture', 'navel', 'tactility', 'Density', 'sugar']]
Y = watermelon3_df
m = list(X.columns)
n0 = gettree(X, Xo, Y, m)
pngname = os.path.basename(os.path.realpath(__file__)).replace('py', 'png')
dealanddraw(n0, pngname)

画图程序 

程序名:createPlot.py

import matplotlib.pyplot as plt

# 用来正常显示负号
plt.rcParams['axes.unicode_minus'] = False
# 设置画节点用的盒子的样式
decisionNode = dict(boxstyle="sawtooth", color='#f05b72')
leafNode = dict(boxstyle="round4", color='#826858')
# 设置画箭头的样式
arrow_args = dict(arrowstyle="<-", color='#121a2a')


def getNumLeafs(myTree):
    # 初始化树的叶子节点个数
    numLeafs = 0
    # myTree.keys()获取树的非叶子节点'no surfacing'和'flippers'
    # list(myTree.keys())[0]获取第一个键名'no surfacing'
    firstStr = list(myTree.keys())[0]
    # 通过键名获取与之对应的值,即{0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
    secondDict = myTree[firstStr]
    # 遍历树,secondDict.keys()获取所有的键
    for key in secondDict.keys():
        # 判断键是否为字典,键名1和其值就组成了一个字典,如果是字典则通过递归继续遍历,寻找叶子节点
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        # 如果不是字典,则叶子结点的数目就加1
        else:
            numLeafs += 1
    # 返回叶子节点的数目
    return numLeafs


def getTreeDepth(myTree):
    # 初始化树的深度
    maxDepth = 0
    # 获取树的第一个键名
    firstStr = list(myTree.keys())[0]
    # 获取键名所对应的值
    secondDict = myTree[firstStr]
    # 遍历树
    for key in secondDict.keys():
        # 如果获取的键是字典,树的深度加1
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        # 去深度的最大值
        if thisDepth > maxDepth: maxDepth = thisDepth
    # 返回树的深度
    return maxDepth


# 绘图相关参数的设置
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    '''
    annotate函数是为绘制图上指定的数据点xy添加一个nodeTxt注释
    nodeTxt是给数据点xy添加一个注释,xy为数据点的开始绘制的坐标,位于节点的中间位置
    xycoords设置指定点xy的坐标类型,xytext为注释的中间点坐标,textcoords设置注释点坐标样式
    bbox设置装注释盒子的样式,arrowprops设置箭头的样式
    '''
    '''
    figure points:表示坐标原点在图的左下角的数据点
    figure pixels:表示坐标原点在图的左下角的像素点
    figure fraction:此时取值是小数,范围是([0,1],[0,1]),在图的左下角时xy是(0,0),最右上角是(1,1)
    其他位置是按相对图的宽高的比例取最小值
    axes points : 表示坐标原点在图中坐标的左下角的数据点
    axes pixels : 表示坐标原点在图中坐标的左下角的像素点
    axes fraction : 与figure fraction类似,只不过相对于图的位置改成是相对于坐标轴的位置
    '''
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# 绘制线中间的文字(0和1)的绘制
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]  # 计算文字的x坐标
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]  # 计算文字的y坐标
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=20)


# 绘制树
def plotTree(myTree, parentPt, nodeTxt):
    # 获取树的叶子节点
    numLeafs = getNumLeafs(myTree)
    # 获取树的深度
    depth = getTreeDepth(myTree)
    # firstStr = myTree.keys()[0]
    # 获取第一个键名
    firstStr = list(myTree.keys())[0]
    # 计算子节点的坐标
    cntrPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yoff)
    # 绘制线上的文字
    plotMidText(cntrPt, parentPt, nodeTxt)
    # 绘制节点
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # 获取第一个键值
    secondDict = myTree[firstStr]
    # 计算节点y方向上的偏移量,根据树的深度
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # 递归绘制树
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # 更新x的偏移量,每个叶子结点x轴方向上的距离为 1/plotTree.totalW
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            # 绘制非叶子节点
            plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
            # 绘制箭头上的标志
            plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD


# 绘制决策树
def createPlot(inTree):
    # 清除figure
    plt.clf()
    axprops = dict(xticks=[], yticks=[])
    # 创建一个1行1列1个figure,并把网格里面的第一个figure的Axes实例返回给ax1作为函数createPlot()
    # 的属性,这个属性ax1相当于一个全局变量,可以给plotNode函数使用
    createPlot.ax1 = plt.subplot(frameon=False, **axprops)
    # 获取树的叶子节点
    plotTree.totalW = float(getNumLeafs(inTree))
    # 获取树的深度
    plotTree.totalD = float(getTreeDepth(inTree))
    # 节点的x轴的偏移量为-1/plotTree.totlaW/2,1为x轴的长度,除以2保证每一个节点的x轴之间的距离为1/plotTree.totlaW*2
    plotTree.xoff = -0.5 / plotTree.totalW
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

代码修改细节

(1)分别使用两个Python程序实现构造树,画树

(2)区分连续性数据以及离散型数据

如果使用的是连续性数据,在gettree方法中,下述语句or后面再加判断条件即可

if root == 'Density' or root == 'sugar' :

如你想增加一个名为sugarRate的属性时,修改为下述语句即可 

if root == 'Density' or root == 'sugar' or root=='sugarRate' :

若是想使用离散型数据,在txt数据集中直接修改即可,不需要修改代码。 

(3) 关于txt

txt文件名为watermelon3.txt,popular是最后的判断属性,因为在主程序中判断是由‘是’/‘否’来判断的,所以我并没有修改。但发现画图程序是不可以显示中文的,所以使用时将所有中文属性改为英文。如下所示。

color	root	knock	texture	navel	tactility	Density	sugar	popular
green	curl	dull	accurate	sunken	smooth	0.697	0.46	是
black	curl	depressing	accurate	sunken	smooth	0.774	0.376	是
black	curl	dull	accurate	sunken	smooth	0.634	0.264	是
green	curl	depressing	accurate	sunken	smooth	0.608	0.318	是
white	curl	dull	accurate	sunken	smooth	0.556	0.215	是
green	LittleCurl	dull	accurate	LittleSunken	sticky	0.403	0.237	是
black	LittleCurl	dull	LittleVague	LittleSunken	sticky	0.481	0.149	是
black	LittleCurl	dull	accurate	LittleSunken	smooth	0.437	0.211	是
black	LittleCurl	depressing	LittleVague	LittleSunken	smooth	0.666	0.091	否
green	hard	clear	accurate	even	sticky	0.243	0.267	否
white	hard	clear	vague	even	smooth	0.245	0.057	否
white	curl	dull	vague	even	sticky	0.343	0.099	否
green	LittleCurl	dull	LittleVague	sunken	smooth	0.639	0.161	否
white	LittleCurl	depressing	LittleVague	sunken	smooth	0.657	0.198	否
black	LittleCurl	dull	accurate	LittleSunken	sticky	0.36	0.37	否
white	curl	dull	vague	even	smooth	0.593	0.042	否
green	curl	depressing	LittleVague	LittleSunken	smooth	0.719	0.103	否
green	curl	dull	accurate	sunken	smooth	0.697	0.46	是
black	curl	depressing	accurate	sunken	smooth	0.774	0.376	是
black	curl	dull	accurate	sunken	smooth	0.634	0.264	是
green	curl	depressing	accurate	sunken	smooth	0.608	0.318	是
white	curl	dull	accurate	sunken	smooth	0.556	0.215	是
green	LittleCurl	dull	accurate	LittleSunken	sticky	0.403	0.237	是
black	LittleCurl	dull	LittleVague	LittleSunken	sticky	0.481	0.149	是
black	LittleCurl	dull	accurate	LittleSunken	smooth	0.437	0.211	是
black	LittleCurl	depressing	LittleVague	LittleSunken	smooth	0.666	0.091	否
green	hard	clear	accurate	even	sticky	0.243	0.267	否
white	hard	clear	vague	even	smooth	0.245	0.057	否
white	curl	dull	vague	even	sticky	0.343	0.099	否
green	LittleCurl	dull	LittleVague	sunken	smooth	0.639	0.161	否
white	LittleCurl	depressing	LittleVague	sunken	smooth	0.657	0.198	否
black	LittleCurl	dull	accurate	LittleSunken	sticky	0.36	0.37	否
white	curl	dull	vague	even	smooth	0.593	0.042	否
green	curl	depressing	LittleVague	LittleSunken	smooth	0.719	0.103	否
green	curl	dull	accurate	sunken	smooth	0.697	0.46	是
black	curl	depressing	accurate	sunken	smooth	0.774	0.376	是
black	curl	dull	accurate	sunken	smooth	0.634	0.264	是
green	curl	depressing	accurate	sunken	smooth	0.608	0.318	是
white	curl	dull	accurate	sunken	smooth	0.556	0.215	是
green	LittleCurl	dull	accurate	LittleSunken	sticky	0.403	0.237	是
black	LittleCurl	dull	LittleVague	LittleSunken	sticky	0.481	0.149	是
black	LittleCurl	dull	accurate	LittleSunken	smooth	0.437	0.211	是
black	LittleCurl	depressing	LittleVague	LittleSunken	smooth	0.666	0.091	否
green	hard	clear	accurate	even	sticky	0.243	0.267	否
white	hard	clear	vague	even	smooth	0.245	0.057	否
white	curl	dull	vague	even	sticky	0.343	0.099	否
green	LittleCurl	dull	LittleVague	sunken	smooth	0.639	0.161	否
white	LittleCurl	depressing	LittleVague	sunken	smooth	0.657	0.198	否
black	LittleCurl	dull	accurate	LittleSunken	sticky	0.36	0.37	否
white	curl	dull	vague	even	smooth	0.593	0.042	否
green	curl	depressing	LittleVague	LittleSunken	smooth	0.719	0.103	否

结果展示

  ​​​​​​​

 

参考 

参考自大佬文章:python 实现对率回归决策树_lazyn的博客-CSDN博客_对率回归决策树 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值