# Machine Learning in Action: decision trees (机器学习实战: 决策树)

from math import log
import csv
import matplotlib.pyplot as plt
import operator
import numpy as np
import pandas as pd

# Matplotlib annotation styles shared by the tree-plotting helpers below.
decisionNode = dict(boxstyle='sawtooth', fc='0.8')  # box style for internal (decision) nodes
leafNode = dict(boxstyle='round4', fc='0.8')  # box style for leaf nodes
arrow_args = dict(arrowstyle='<-')  # arrow drawn from child node back to its parent


def pltNode(txt, centerPt, parentPt, type):
    """Draw one annotated tree node at *centerPt* with an arrow from *parentPt*.

    Relies on ``createPlt.ax1`` having been set up by ``createPlt`` first.
    ``type`` is a bbox style dict (``decisionNode`` or ``leafNode``).
    """
    createPlt.ax1.annotate(
        txt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=type,
        arrowprops=arrow_args,
    )


def createPlt():
    """Create a demo figure showing one decision node and one leaf node."""
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Stash the axes on the function object so pltNode can reach it.
    createPlt.ax1 = plt.subplot(111, frameon=False)
    pltNode('a decision node ', (0.9, 0.1), (0.1, 0.5), decisionNode)
    pltNode('a leaf node ', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    A tree is ``{feature_label: {feature_value: subtree_or_class, ...}}``;
    any child that is not a dict counts as one leaf.
    """
    root = list(myTree.keys())[0]
    total = 0
    for child in myTree[root].values():
        if type(child).__name__ == 'dict':
            total += getNumLeafs(child)  # recurse into the subtree
        else:
            total += 1  # class label -> one leaf
    return total


def getDepth(myTree):
    """Return the maximum depth of a nested-dict decision tree.

    Depth is counted in edges below the root split: a tree whose children
    are all class labels has depth 1.
    (Fix: removed a leftover debug ``print`` from the loop body.)
    """
    root = list(myTree.keys())[0]
    maxDepth = 0
    for child in myTree[root].values():
        if type(child).__name__ == 'dict':
            depth = 1 + getDepth(child)  # subtree adds one level
        else:
            depth = 1  # leaf directly under this split
        if depth > maxDepth:
            maxDepth = depth
    return maxDepth


# 计算香农熵
def calShannonEnt(dataSet):
    num = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        # print('featVec = ', featVec)
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1

        # print('labelCounts = ', labelCounts)
    shannonEnt = 0.0
    for key in labelCounts:
        prob = labelCounts[key] / num
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


# 划分数据集,axis 第几列,value 该列的特征值
def splitDataSet(dateSet, axis, value):
    retDateSet = []
    for featVec in dateSet:
        # print('split  featVec = ', featVec)
        # 按列划分,每个数据如果该列的特征值等于 value ,则把该数据添加到数据子集里,但不包含该列
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            # print('reduceFeatVec = ', reduceFeatVec)
            reduceFeatVec.extend(featVec[axis + 1:])
            # print('变化完 = ', reduceFeatVec)
            retDateSet.append(reduceFeatVec)
            # print('retDateSet = ', retDateSet)
    return retDateSet


# 创建数据集
def createDataSet():
    dateSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dateSet, labels


# 选择最好的划分方式
def bestSplit(dataSet):
    n = len(dataSet[0]) - 1  # 特征的个数
    baseEnt = calShannonEnt(dataSet)
    bestGain = 0.0
    bestFeature = -1
    # 按列循环判断
    for i in range(n):
        # print('i = ', i)
        feat = [a[i] for a in dataSet]  # 列表生成器,该列的所有值
        uniqueVals = set(feat)  # 特征值的集合
        newEnt = 0.0
        # 按每个特征值都划分一遍子集
        for value in uniqueVals:
            # print('value = ', value)
            subDataSet = splitDataSet(dataSet, i, value)
            # print('subDataSet = ', subDataSet)
            # 计算信息熵
            prob = len(subDataSet) / len(dataSet)
            newEnt += prob * calShannonEnt(subDataSet)
            # print('newEnt = ', newEnt)
        # print('第 ', i, ' 个特征的 ', 'newEnt = ', newEnt)
        gain = baseEnt - newEnt  # 计算增益
        if gain > bestGain:
            bestGain = gain
            bestFeature = i
    return bestFeature


def vote(classList):
    """Return the majority class label from *classList* (majority vote).

    Ties break toward the label that appears first in the list (stable sort
    over insertion-ordered counts).
    (Fix: previously returned ``sortClass[0][1]`` — the vote COUNT — instead
    of ``sortClass[0][0]``, the label; createTree expects a label here.)
    """
    classCount = {}
    for label in classList:
        classCount[label] = classCount.get(label, 0) + 1
    sortClass = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortClass[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    dataSet: list of rows, last element of each row is the class label.
    labels:  feature names aligned with the feature columns.
    Returns either a class label (leaf) or ``{label: {value: subtree}}``.

    (Fix: the original did ``del labels[bestFeat]``, mutating the caller's
    list as a side effect; this version works on a copy, producing the
    identical tree without touching the argument.)
    """
    classList = [row[-1] for row in dataSet]
    if classList.count(classList[0]) == len(classList):
        # All rows share one class: leaf node.
        return classList[0]
    if len(dataSet[0]) == 1:
        # Features exhausted: fall back to majority vote.
        return vote(classList)

    bestFeat = bestSplit(dataSet)
    bestLabel = labels[bestFeat]
    # Remaining feature names for the subtrees (copy, don't mutate caller's list).
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    myTree = {bestLabel: {}}
    for v in set(row[bestFeat] for row in dataSet):
        myTree[bestLabel][v] = createTree(splitDataSet(dataSet, bestFeat, v), subLabels[:])
    return myTree


if __name__ == '__main__':
    # Demo: build the toy dataset and print the resulting decision tree.
    dataSet, featureLabels = createDataSet()
    tree = createTree(dataSet, featureLabels)
    print(tree)

 
