程序:
from math import *
import operator
def calShannonEnt(dataSet): #熵的计算函数,熵的计算只与标签值有关
numEntries = len(dataSet) #len(dataSet)计算dataSet矩阵的行数,和dataSet.shape[0]功能一样
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1] #取每组数据最后一个标签数据
#if currentLabel not in labelCounts.keys():
# labelCounts[currentLabel] = 0
#labelCounts[currentLabel] += 1
labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
shannonEnt = 0.0
for key in labelCounts:
prob = float( labelCounts[key] ) / numEntries #每一种标签出现的概率
shannonEnt -= prob * log(prob, 2) #熵的计算
return shannonEnt
def createDataSet(): #海洋生物数据
dataSet = [
[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']
] #列表之中套列表
labels = ['no surfacing', 'flippers']
return dataSet, labels
def splitDataSet(dataSet, axis, value):
retDataSet = [] #创建一个新列表,准备提取数据
for featVec in dataSet: #dataSet为[ [], [], [], ....]形式
if featVec[axis] == value: #每个列表axis位置的值符合规定的特征
reducedFeatVec = featVec[:axis] #reducedFeatVec先取featVec[0:axis]
reducedFeatVec.extend( featVec[axis+1:] ) #reducedFeatVec再补上featVec[axis+1:]
retDataSet.append(reducedFeatVec) #retDataSet空列表装入reducedFeatVec,相当于没有改变原数据
return retDataSet #retDataSet为[ [], [], [], ....]形式,里面的列表为符合规定特征的列表,且没有axis位置的数据
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #除去最后一个标签值外,特征的个数
baseEntropy = calShannonEnt(dataSet) #原始香农熵
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [ example[i] for example in dataSet] #dataSet中每个列表的第1个特征值提出来
uniqueVals = set(featList) #集合,没有重复元素,即可以看出每个小列表第一个特征有多少种不同的
newEntropy = 0.0
for value in uniqueVals: #uniqueVals为第一个特征所有不重复的特征值
subDataSet = splitDataSet(dataSet, i, value) #subDataSet为[ [], [], [], ....]形式,[]为去除第一个特征的,符合规定好的第一个特征值的列表
prob = len(subDataSet) / float( len(dataSet) ) #计算此特征中的一个分类的概率
newEntropy += prob * calShannonEnt(subDataSet) #该特征下香农熵的期望值
infoGain = baseEntropy - newEntropy #原始香农熵减去第一个特征值的香农熵,熵的减少或者说无序度的减少,就是信息增益
if (infoGain > bestInfoGain): #原始熵大于新熵,无序度降低,infoGain信息增益(无序度减少的数量)
bestInfoGain = infoGain #下一次循环若信息增益比前一次还要大,则......
bestFeature = i
return bestFeature
def majorityCnt(classList):
classCount = {}
for vote in classList:
classCount[vote] = classCount.get(vote, 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels): #labels为数据所有特征的标签,比如有无蹼,离水能否生存
classList = [ example[-1] for example in dataSet ] #提取每组数据标签值
if classList.count( classList[0] ) == len(classList):
return classList[0] #递归终止条件1:所有标签值全部一致,返回该标签值
if len( dataSet[0] ) == 1: #dataSet[0]取矩阵第一行,dataSet[0][0]取矩阵第一行第一列元素
return majorityCnt(classList) #递归终止条件2:使用完所有特征,则返回最后出现次数最多的那个标签
#上述两个终止条件相当于判断当前是判断节点还是叶子节点,也就是: 是方框还是圆框
bestFeat = chooseBestFeatureToSplit(dataSet) #以上两个终止条件都不满足,开始选择最优特征划分,已经有了一个方框,准备往方框中写入判断问题
bestFeatLabel = labels[bestFeat] #最优特征的名字赋给bestFeatLabel
myTree = { bestFeatLabel:{} } #形成字典{ 判断1:{} }
del( labels[bestFeat] ) #用过了该特征,将该特征从所有特征列表中删除
featValues = [ example[bestFeat] for example in dataSet ] #得出该特征下一共有几种分类
uniqueVals = set(featValues) #去除重复
for value in uniqueVals:
subLabels = labels[:] #subLabels为去除已经用过的特征后剩下的特征列表
splitdata = splitDataSet(dataSet, bestFeat, value)
myTree[bestFeatLabel][value] = createTree( splitdata, subLabels )
#递归,
return myTree
学习要点:
1.列表.extend()和.append()的区别
>>> a = [1, 2, 3]
>>> a.extend([1, 2])
>>> print(a)
[1, 2, 3, 1, 2]
>>> a.append([1, 2])
>>> print(a)
[1, 2, 3, 1, 2, [1, 2]]
>>> a.extend(4) #括号内只能是一个列表
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: 'int' object is not iterable
>>> a.extend([4])
>>> print(a)
[1, 2, 3, 1, 2, [1, 2], 4]
>>> a.extend([2,3],[4,5])
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: extend() takes exactly one argument (2 given)
>>> a.append([2,3],[4,5]) #可以是一个元素,一个列表,一个元组,但不允许出现2个
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: append() takes exactly one argument (2 given)
2.一个注意点:
在定义的函数中对外部存在的列表进行修改,调用函数后将会确确实实的改变列表中的数据,即改变了内存中的数据:
a = [1, 2, 3, 4, 5]
def change(b):
b.append(6)
change(a)
print(a) #结果:[1, 2, 3, 4, 5, 6]
3.关于列表中的数是str还是int:
若是由split()分割而来的数字肯定是str
直接输入的为int,如:
a = [1, 2, 3, 4, 5]
b = a[0]
print(type(b)) #int
b = a[0:1] #b = [1],为列表,这种取法生成的都是列表
4.列表推导:
[表达式 for 变量 in 列表] 或者 [表达式 for 变量 in 列表 if 条件]
>>> a = [1, 2, 3, 4 ,5]
>>> b = [ x**2 for x in a ]
>>> b
[1, 4, 9, 16, 25]
>>> c = [x**2 for x in a if x>2]
>>> c
[9, 16, 25]
>>> d = dict( [(x, 2*x) for x in a] )
>>> d
{1: 2, 2: 4, 3: 6, 4: 8, 5: 10}
>>> e = [ (x, y) for x in a if x>3 for y in a if y<3 ]
>>> e
[(4, 1), (4, 2), (5, 1), (5, 2)]
5.del的用法:
>>> a = [1,2,3,4]
>>> del(a[0]) #删除a的第一个元素
>>> a
[2, 3, 4]
>>> del a
>>> a #删除a这个变量
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name 'a' is not defined