决策树(ID3算法)Python实现

原理:

  • 决策树其实就是一个if-else集合,通过一系列的分支将原始数据划分成最后的若干类。
  • 每个节点都是一个数据集合,对这个数据集合选择一个最优的划分特征,根据该特征来对集合进行划分,生成若干子集合。
  • 选择最优特征用到了信息学中香农熵的概念,其中,信息增益=原始熵-划分之后的熵。
  • 递归地构建决策树。

注意点:

  • 本文所用的决策树要求数据集为Python中的列表变量。
  • 每一条实例数据的最后一列元素为当前实例的类别标签。
  • 最后生成的决策树为字典变量。
  • 代码基于《机器学习实战》一书,本文只实现了决策树的构建,没有考虑剪枝,实际会存在过拟合的问题。

代码:

# coding=utf-8
# Python 2.7
# 决策树实现

from math import *
import operator

# 创建数据集
def createDataSet():
    dataSet = [[1, 1, 'yes'],
                [1, 1, 'yes'],
                [1, 0, 'no'],
                [0, 1, 'no'],
                [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

# 对于每个数据集计算熵
def calShannonEnt(dataSet) :
    numEntries = len(dataSet)
    labelsCount = {}                # 数据字典用来保存每一类元素的个数,类似map
    for featVec in dataSet:
        labelsName = featVec[-1]    # 以最后一列作为分类标记
        if (labelsName not in labelsCount.keys()):
            labelsCount[labelsName] = 0
        labelsCount[labelsName] += 1
    shannonEnt = 0.0
    for key in labelsCount:
        prob = labelsCount[key] * 1.0 / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

# 分割数据,即在数据集中选出第axis维值为value的数据,构成一个新的数据集
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]     # 新数据集每一条数据不包括axis这一列
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# 根据信息增益选择最好的特征来划分数据
def chooseBestFeatureToSplit(dataSet):
    sumFeatures = len(dataSet[0]) - 1       # 数据总特征数
    baseEntropy = calShannonEnt(dataSet)    # 原始数据集的熵
    bestInfoGain = 0.0                      # 信息增益
    bestFeature = -1                        # 最优划分特征
    for i in range(sumFeatures):            # 枚举每一个特征
        featList = [x[i] for x in dataSet]
        uniqueValue = set(featList)         # 得到该特征所有可能值
        newEntropy = 0.0
        for value in uniqueValue:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) * 1.0 / len(dataSet)
            newEntropy += prob * calShannonEnt(subDataSet)   # 对按照该特征划分的每一类计算信息熵,
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

# 多数表决函数
def mayorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# 创建决策树
def createTree(dataSet, labels):
    classList = [x[-1] for x in dataSet]
    if classList.count(classList[0]) == len(classList):     # 该集合只有一种类别时退出
        return classList[0]
    if len(dataSet[0]) == 1:                                # 当所有特征已经使用完,退出
        return mayorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)            # 最优特征
    bestFeatLabel = labels[bestFeat]                        # 最优特征名
    myTree = {bestFeatLabel : {}}                           # 决策树保存为一个数据字典
    del(labels[bestFeat])
    featValues = [x[bestFeat] for x in dataSet]
    uniqueFeatValues = set(featValues)
    for value in uniqueFeatValues:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)    # 递归创建
    return myTree

myData, labels = createDataSet()
featLabels = labels
print myData
tree = createTree(myData, labels)
print tree

# 利用决策树进行分类
def classify(tree, labels, testVec):
    firstStr = tree.keys()[0]
    secondDict = tree[firstStr]
    featIndex = labels.index(firstStr) # 找到当前分类特征在向量中的下标
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], labels, testVec)
            else: classLabel = secondDict[key]
    return classLabel

testVec = [1, 1]        # 测试向量
print testVec
print classify(tree, featLabels, testVec)

# 储存决策树,序列化
def storeTree(tree, fileName):
    import pickle
    fw = open(fileName, 'w')
    pickle.dump(tree, fw)
    fw.close()

# 读取决策树
def grabTree(fileName):
    import pickle
    fr = open(fileName)
    return pickle.load(fr)

storeTree(tree, 'hahaha.txt')
print grabTree('hahaha.txt')
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值