最近在学树模型,所以今天花了点时间把《机器学习实战》上的ID3算法敲了一遍。这个算法比较简单,不过因为对Python还不够熟悉,我是照着书敲的,顺便也学了一些Python的函数。代码在这里留个档,方便以后回头看。
# -*- coding: utf-8 -*-
"""
@author: 沈同学
"""
from math import log
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in dataSet.

    Args:
        dataSet: A sequence of records (feature vectors). The last element
            of each record is used as its class label; labels must be
            hashable because they are used as dictionary keys.

    Returns:
        The entropy ``-sum(p * log2(p))`` over the observed label
        frequencies, as a float. An empty dataSet yields 0.0.
    """
    numEntries = len(dataSet)

    # Tally how many records carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1

    shannonEnt = 0.0
    for count in labelCounts.values():
        # float() keeps this a true division even under Python 2 semantics.
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
#---------------------------------------------------------------------------------
def createDataSet():
dataSet=[[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels=