(一)计算给定数据集的香农熵(个人理解为计算给定信息集纯度的一种数学计算指标):
from math import log
def calcShannonEnt(dataSet):#calculata shannonEnt
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:#将当前键值加入字典并记录类别出现的次数
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:#计算香农熵
prob = float(labelCounts[key])/numEntries#使用所有类标签的发生频率计算类别出现的概率
shannonEnt -= prob*log(prob,2)#得到香农熵
return shannonEnt
测试代码:
def createDataSet():
dataSet = [[1,1,'maybe'],
[1, 1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels
#放在另一个py文件内的test:
import CreateDataSet
import trees
myDat,labels=CreateDataSet.createDataSet()
print(myDat)
print(trees.calcShannonEnt(myDat))
(二)划分数据集:
需要的python基础:也可看我整理出来的文章
前期准备(人为划分,给定属性以及相应的值,作为后面函数的调用)
def splitDataSet(dataSet,axis,value):#将属性axis中满足值为value的数据划分出来
retDataSet = []#Python在函数中传递的是列表的引用,在函数内部对列表对象的修改将会影响该列表的整个生命周期。为了消除这个不良影响,需要在函数的开始声明一个新列表对象。
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]#通过以下两步可得到满足所给条件的除去属性(axis+1发挥的作用)axis的数据
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
测试代码:
import CreateDataSet
import trees
myDat,labels=CreateDataSet.createDataSet()
print(trees.splitDataSet(myDat,1,1))
正式划分:(利用信息增益得到所有属性中最适合划分的一个)
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0;bestFeature=-1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]#列表解析(推导),得到dataSet中的第i个属性的所有取值eg:(1,1,1,0,0)
uniqueVals = set (featList)#通过集合中元素唯一的特性,将得到的featList中的重复元素变唯一eg:(1,0)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
test:
import CreateDataSet
import trees
myDat,labels=CreateDataSet.createDataSet()
print(trees.chooseBestFeatureToSplit(myDat))
“`