Python编写决策树算法

main.py
import numpy as np
# pickle用于进行序列化与反序列化
# 序列化过程将文本信息转变为二进制数据流。这样就信息就容易存储在硬盘之中,
# 当需要读取文件的时候,从硬盘中读取数据,然后再将其反序列化便可以得到原始的数据。
import pickle
import os
import treePlotter

# 创建训练数据
def CreateTrainingDataset():
    X = [[0, 2, 0, 0, 'N'],
         [0, 2, 0, 1, 'N'],
         [1, 2, 0, 0, 'Y'],
         [2, 1, 0, 0, 'Y'],
         [2, 0, 1, 0, 'Y'],
         [2, 0, 1, 1, 'N'],
         [1, 0, 1, 1, 'Y'],
         [0, 1, 0, 0, 'N'],
         [0, 0, 1, 0, 'Y'],
         [2, 1, 1, 0, 'Y'],
         [0, 1, 1, 1, 'Y'],
         [1, 1, 0, 1, 'Y'],
         [1, 2, 1, 0, 'Y'],
         [2, 1, 0, 1, 'N']]
    attributeList = ["age", "income", "student", "credit_rating"]
    return X, attributeList

# 创建测试数据
def CreateTestDataset():
    X = [[0, 1, 0, 0],
         [0, 2, 1, 0],
         [2, 1, 1, 0],
         [0, 1, 1, 1],
         [1, 1, 0, 1],
         [1, 0, 1, 0],
         [2, 1, 0, 1]]

    attributeList = ["age", "income", "student", "credit_rating"]
    return X, attributeList

# 计算类别的统计信息
def GetClassInfo(Dataset):    # 例如{'Y': 10, 'N':5}
    classInfo = {}
    for item in Dataset:
        if item[-1] not in classInfo.keys():
            classInfo[item[-1]] = 1
        else:
            classInfo[item[-1]] += 1
    classInfo = dict(sorted(classInfo.items(), key=lambda x: x[1], reverse=True))

    return classInfo

# 计算最大占比类
def CalMostClass(classInfo):
    maxClass = list(classInfo.keys())[0]
    return maxClass

# 计算数据集的信息熵
def ComputeEntropy(Dataset):
    ClassInfo = GetClassInfo(Dataset)
    entropy = 0
    amount = 0
    p = []  # p[]存放的是第k个类的数据个数

    for _, val in ClassInfo.items():
        p.append(val)
        amount += val
    for pk in p:
        entropy -= (pk / amount) * np.log2(pk / amount)

    return entropy

# 计算数据集在某个属性上的的信息增益Gain(attributeList)
# Gain(D, a)
def computeAttrGainNPartition(Dataset, attributeIndex):
    gain = ComputeEntropy(Dataset)  # Initialize:初始化等于数据集D的信息熵

    # 按属性的值划分数据集子集
    LEN_DATASET = len(Dataset)
    # attributePartition = {"attrVal1": [[], [] ,.., []], ..., "attrValn": [[], [] ,.., []]}
    attributePartition = {}
    for dataItem in Dataset:
        if dataItem[attributeIndex] not in attributePartition.keys():
            attributePartition[dataItem[attributeIndex]] = []
            attributePartition[dataItem[attributeIndex]].append(dataItem)
        else:
            attributePartition[dataItem[attributeIndex]].append(dataItem)

    amount = 0
    lenth = []
    Ent = []

    # 计算信息增益
    for key, valDataSet in attributePartition.items():
        Ent.append(ComputeEntropy(valDataSet))
        lenth.append(len(valDataSet))
        amount += len(valDataSet)
    for i in range(len(Ent)):
        gain -= (lenth[i] / LEN_DATASET) * Ent[i]

    return gain, attributePartition

# 建决策树
def CreateDecisionTree(Dataset, attributeList):
    attrList = attributeList
    Tree = {}
    classInfo = GetClassInfo(Dataset)
    LEN_DATASET = len(Dataset)

    # 建立叶子节点情况1:给定的属性集为空 ---- 不能划分
    if len(attributeList) == 0:
        return CalMostClass(classInfo)

    # 建立叶子节点情况2:给定的数据集所有label都相同 ---- 无需划分
    for key, valLen in classInfo.items():
        if valLen == LEN_DATASET:
            return key
        break

    # 建立叶子节点情况3:样本在属性集上取值都相等 ---- 无法划分
    temp = Dataset[0][:-1]
    sameCnt = 0
    for dataItem in Dataset:
        if temp == dataItem[:-1]:
            sameCnt += 1
    if sameCnt == LEN_DATASET:
        return CalMostClass(classInfo)

    # 选择最佳划分属性
    theBestAttrIndex = 0
    theBestAttrGain = 0
    theBestAttrPartition = {}

    for attributeIndex in range(len(attributeList)):
        gain, attributePartition = computeAttrGainNPartition(Dataset, attributeIndex)

        if gain > theBestAttrGain:
            theBestAttrGain = gain
            theBestAttrIndex = attributeIndex
            theBestAttrPartition = attributePartition

    attrName = attributeList[theBestAttrIndex]
    # python的list对象按索引删除对象,使用的是del()函数
    del (attributeList[theBestAttrIndex])

    # # 为了方便后面建子树,将此时的attr对应的那列去除
    for key, valList in theBestAttrPartition.items():
        for index in range(len(valList)):
            temp = valList[index][:theBestAttrIndex]
            temp.extend(valList[index][theBestAttrIndex + 1:])
            valList[index] = temp


    # 根据属性的值,建立分叉节点
    Tree[attrName] = {}

    for keyAttrVal, valDataset in theBestAttrPartition.items():
        # 因为python对iterable list对象的传参是按地址传参,会改变attributeList的值
        # 所以在传attributeList参数的时候,创建一个副本,就相当于按值传递了
        subLabels = attributeList[:]
        # valDataset是已去除attr的data,attributeList是已去除attr的attributeList
        Tree[attrName][keyAttrVal] = CreateDecisionTree(valDataset, subLabels)

    return Tree

# 测试做分类
def Predict(DataSet, testArrtList, decisionTree):
    predicted_label = []

    for dataItem in DataSet:
        cur_decisionTree = decisionTree
        # 如果root就是叶子结点leaf
        if type(cur_decisionTree) == set:   # 例如:{'N'}
            node = list(cur_decisionTree)
        else:
            node = list(cur_decisionTree.keys())[0]
            # 只要temp处在attributeList,说明当前处在树枝结点(非叶子)上, 否则处在叶子结点
        while node in testArrtList:
            cur_index = testArrtList.index(node)  # 0 2
            cur_element = dataItem[cur_index]  # 0 0
            cur_decisionTree = cur_decisionTree[node][cur_element]  # {'student': {0: 'N', 1: 'Y'}} N
            if type(cur_decisionTree) == dict:
                node = list(cur_decisionTree.keys())[0]  # student
            else:
                node = cur_decisionTree
        predicted_label.append(node)
    return predicted_label

# 将模型保存起来
def SaveModel(decisionTree, filename):
    # 由于pickle是将文本序列化成binary文件,故需用wb
    f = open(filename, 'wb')
    pickle.dump(decisionTree, f)

# 读取模型
def LoadModel(filename):
    # 由于pickle读取的是binary文件,故需用rb
    f = open(filename, 'rb')
    return pickle.load(f)

if __name__ == '__main__':
    base = os.path.dirname(os.path.abspath(__file__))
    trainingDataset, attributeList = CreateTrainingDataset()
    testDataset, testArrtList = CreateTestDataset()

    path = base + "/DecisionTreeModel.txt"
    print(path)
    # 建决策树
    decisionTree = CreateDecisionTree(trainingDataset, attributeList)
    # 保存模型
    SaveModel(decisionTree, path)
    # 读取模型
    model = LoadModel(path)
    print(model)
    # 对测试数据进行预测label
    result = Predict(testDataset, testArrtList ,model)
    print(result)

    treePlotter.createPlot(model)


# 链接: https://pan.baidu.com/s/1JUl30wy8-h4cLlBLlNa0zg  密码: uds1
treePlotter.py
  • 这个程序调用的是别人的程序,遗憾的是找不到出处了
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2019/1/28 下午 09:02
# @Author  : YuXin Chen

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
                            xytext=centerPt, textcoords='axes fraction', \
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = getTreeDepth(secondDict[key]) + 1
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalw = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalw
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值