# 统计学习方法第五章决策树C4.5算法代码实践(其实相对于ID3算法只是寻找最优划分属性的标准发生了改变,ID3为信息增益而C4.5为信息增益率即在ID3的信息增益除以数据集关于该特征划分的熵)
from numpy import *
import math
def loadDataSet():
    """Return the book's worked-example data set (loan applications).

    Each row is [age, has-job, owns-house, credit-rating, class-label];
    the last column is the classification result ('是'/'否').
    Returns (samples, featureNames).
    """
    samples = [
        ['青年', '否', '否', '一般', '否'],
        ['青年', '否', '否', '好', '否'],
        ['青年', '是', '否', '好', '是'],
        ['青年', '是', '是', '一般', '是'],
        ['青年', '否', '否', '一般', '否'],
        ['中年', '否', '否', '一般', '否'],
        ['中年', '否', '否', '好', '否'],
        ['中年', '是', '是', '好', '是'],
        ['中年', '否', '是', '非常好', '是'],
        ['中年', '否', '是', '非常好', '是'],
        ['老年', '否', '是', '非常好', '是'],
        ['老年', '否', '是', '好', '是'],
        ['老年', '是', '否', '好', '是'],
        ['老年', '是', '否', '非常好', '是'],
        ['老年', '否', '否', '一般', '否'],
    ]
    featureNames = ['年龄', '有工作', '有自己的房子', '信贷情况']
    return samples, featureNames
def calculateEntopy(dataSet):
    """Compute the base-2 Shannon entropy of the class labels in dataSet.

    The class label is taken to be the last element of every row.
    Returns 0.0 for a single-class data set.
    """
    total = float(len(dataSet))
    # Tally each label's frequency in a single pass.
    counts = {}
    for row in dataSet:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    entropy = 0.0
    for freq in counts.values():
        p = freq / total
        entropy -= p * math.log(p, 2)
    return entropy
def splitDataSet(dataSet, i, value):
    """Return the rows whose i-th feature equals value, with column i removed.

    The returned rows are new lists; dataSet itself is not modified.
    """
    return [row[:i] + row[i + 1:] for row in dataSet if row[i] == value]
def calculateGainRate(dataSet, i):
    """Compute the C4.5 information gain ratio for splitting on feature i.

    gainRatio = (baseEntropy - conditionalEntropy) / splitInfo, where
    splitInfo is the entropy of dataSet with respect to the values taken by
    feature i (this denominator is what distinguishes C4.5 from ID3).

    Returns 0.0 when feature i takes a single value: splitInfo is then zero
    and the ratio is undefined (the original code raised ZeroDivisionError).

    Depends on the sibling functions calculateEntopy and splitDataSet.
    """
    m = float(len(dataSet))
    baseEntropy = calculateEntopy(dataSet)
    uniqueValues = set(data[i] for data in dataSet)
    newEntropy = 0.0   # conditional entropy H(D|A)
    splitInfo = 0.0    # entropy of dataSet w.r.t. the values of feature i
    for value in uniqueValues:
        subset = splitDataSet(dataSet, i, value)
        prob = len(subset) / m
        newEntropy += prob * calculateEntopy(subset)
        splitInfo -= prob * math.log(prob, 2)
    if splitInfo == 0.0:
        # Constant feature: no split possible, so its gain ratio is 0.
        return 0.0
    return (baseEntropy - newEntropy) / splitInfo
def chooseBestValueToSplit(dataSet):
    """Return the index of the feature with the highest C4.5 gain ratio.

    Iterates over every feature column (all but the last, which holds the
    class label) and keeps the best-scoring index.  Falls back to index 0
    when no feature yields a positive gain ratio.

    Fixes: the original initialized the result as the float 0.0 even though
    it is a column index, so callers could receive a float; it also pulled
    in numpy's shape() just to count columns.

    Depends on the sibling function calculateGainRate.
    """
    numFeatures = len(dataSet[0]) - 1
    bestInfoGainRate = 0.0
    bestFeature = 0  # index must be an int, not 0.0
    for i in range(numFeatures):
        curInfoGainRate = calculateGainRate(dataSet, i)
        if curInfoGainRate > bestInfoGainRate:
            bestInfoGainRate = curInfoGainRate
            bestFeature = i
    return bestFeature
def maxResult(resultList):
    """Return the most frequent class label in resultList (majority vote).

    Fixes: the original built a {count: label} dict and indexed it with the
    max count; labels sharing a count silently overwrote each other, and the
    intent was opaque.  Counting directly is clearer and robust.  Tie-breaking
    among equally frequent labels remains unspecified, as before.

    Raises ValueError on an empty list (as the original's max() did).
    """
    return max(set(resultList), key=resultList.count)
def createTree(dataSet,labels): # first decide whether further splitting is needed
    """Recursively build the C4.5 decision tree as nested dicts.

    dataSet: rows of feature values with the class label as the last element.
    labels:  feature names aligned with the columns.
             NOTE(review): this list is mutated (del below), so the caller's
             list shrinks as the tree is built -- confirm callers expect this.
    Returns either a class label string (leaf) or
    {featureName: {featureValue: subtree}}.
    """
    resultList=[data[-1] for data in dataSet]
    # Base case 1: all feature columns consumed -> majority-vote leaf.
    if len(dataSet[0])==1:
        return maxResult(resultList)
    # Base case 2: every sample has the same class -> pure leaf.
    if resultList.count(resultList[0]) == len(resultList):
        return resultList[0]
    bestValue=chooseBestValueToSplit(dataSet)
    bestLabel=labels[bestValue]
    tree={bestLabel:{}}
    # Keep labels aligned with the column that splitDataSet removes below.
    del(labels[bestValue])
    uniqueValue=set([data[bestValue] for data in dataSet])
    for value in uniqueValue:
        returnDataSet=splitDataSet(dataSet,bestValue,value)
        # Copy so recursion in one branch does not affect sibling branches.
        subLabels=labels[:]
        subTree=createTree(returnDataSet,subLabels)
        tree[bestLabel][value]=subTree
    return tree
# Guard the script entry point so importing this module as a library does
# not immediately build and print the tree.
if __name__ == "__main__":
    dataSet, label = loadDataSet()
    print(createTree(dataSet, label))
# 结果为:(基于本书上的数据集好像decision tree 并没有发生变化)
# {'有自己的房子': {'否': {'有工作': {'否': '否', '是': '是'}}, '是': '是'}}
# made by zcl at CUMT
# I know I can because I have a heart that beats