from math import log
import operator
# Compute the Shannon entropy of the class labels; the label of each row in data is its last element.
def calcuShannon(data):
    """Return the base-2 Shannon entropy of the class labels in ``data``.

    ``data`` is a sequence of rows; the last element of each row is taken
    as that row's class label. An empty ``data`` yields ``0``.
    """
    # Tally how many rows carry each label.
    tally = {}
    for row in data:
        cls = row[-1]
        if cls in tally:
            tally[cls] += 1
        else:
            tally[cls] = 1

    total = len(data)
    entropy = 0
    for count in tally.values():
        p = count / total
        entropy -= p * log(p, 2)
    return entropy
#返回只留符合划分特征的数据集、二维List (每个元素的特征值会删去)
#data不变 返回的是其完全复制子data
def splitData(data, axis, value):
    """Return the rows of ``data`` whose element at ``axis`` equals ``value``.

    Each returned row is a new list with the element at ``axis`` removed.
    ``data`` and its rows are left unmodified.
    """
    subset = []
    for row in data:
        if row[axis] != value:
            continue
        # Rebuild the row without the column that was used for the split.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        subset.append(trimmed)
    return subset
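# Choose the split that reduces the entropy the most and return the index of that feature.
# for each feature:
#     collect its values into a set (drop duplicates)
#     for each distinct value:
#         accumulate the weighted entropy of the matching subset
#     compare the total with the best result so far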
def bestWayOfSplit(data):
    """Return the index of the feature whose split gives the lowest entropy.

    ``data`` is a sequence of rows; every element but the last is a feature
    value and the last element is the class label. For each feature the
    rows are partitioned by that feature's distinct values and the
    size-weighted entropy of the partitions is compared against the best
    entropy seen so far, starting from the entropy of the unsplit data.

    Returns -1 when ``data`` is empty or when no feature strictly lowers
    the entropy.
    """
    if len(data) == 0:
        return -1

    bestEntroy = calcuShannon(data)  # baseline: entropy before any split
    bestFeatureIndex = -1
    total = len(data)
    numFeatures = len(data[0]) - 1  # last column is the label
    for featureIndex in range(numFeatures):
        # Distinct values of THIS feature's column. Reading column 0 for
        # every feature would split the other features on foreign values.
        featureSet = {row[featureIndex] for row in data}
        curEntroy = 0.0
        for value in featureSet:
            subData = splitData(data, featureIndex, value)
            curEntroy += calcuShannon(subData) * len(subData) / total
        # A split can be useless; only accept one that lowers the entropy.
        if curEntroy < bestEntroy:
            bestEntroy = curEntroy
            bestFeatureIndex = featureIndex
    return bestFeatureIndex
#返回最多数的标签(例子:是否为鱼类)
def mostLabel(labels):
    """Return the label that occurs most often in ``labels``.

    When several labels share the highest count, the one that appears
    first in ``labels`` is returned. An empty ``labels`` raises IndexError.
    """
    tally = {}
    for item in labels:
        if item in tally:
            tally[item] += 1
        else:
            tally[item] = 1

    # Stable sort by count, highest first, so ties keep first-seen order.
    ranked = list(tally.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    winner, _count = ranked[0]
    return winner
# Build the decision tree as nested dicts (data is a 2-D list, features is a 1-D list).
def createTree(data, features):
    """Recursively build a decision tree as nested dicts.

    ``data`` is a list of rows whose last element is the class label and
    ``features`` holds the feature names, in the same order as the feature
    columns of ``data``.

    The result is either a class label (leaf) or a dict of the form
    ``{feature_name: {feature_value: subtree_or_label, ...}}``.

    ``features`` is not modified: each recursion level works on its own
    copy, so the caller can keep using the same list afterwards (for
    example to look up feature positions when classifying).
    """
    labelList = [row[-1] for row in data]
    # Every remaining row has the same class: this is a leaf.
    if labelList.count(labelList[0]) == len(labelList):
        return labelList[0]
    # Only the label column is left: fall back to the majority class.
    if len(data[0]) == 1:
        return mostLabel(labelList)

    bestFeatureIndex = bestWayOfSplit(data)
    # -1 means no feature lowers the entropy. Indexing with -1 would pick
    # the label column as the split feature, so stop with the majority
    # class instead.
    if bestFeatureIndex < 0:
        return mostLabel(labelList)

    bestFeature = features[bestFeatureIndex]
    # Feature names for the subtrees: a copy without the consumed feature,
    # keeping names aligned with the columns returned by splitData.
    subFeatures = list(features)
    del subFeatures[bestFeatureIndex]

    uniqueFeatureValue = {row[bestFeatureIndex] for row in data}
    myTree = {bestFeature: {}}
    for value in uniqueFeatureValue:
        myTree[bestFeature][value] = createTree(
            splitData(data, bestFeatureIndex, value), subFeatures)
    return myTree
# Disabled smoke test: the bare triple-quoted string below is never executed.
# It defines a five-row toy data set with two features, builds a tree from it
# with createTree and prints the result.
'''
#测试
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no sufacing', 'flippers']
return dataSet, labels
myDat, labels = createDataSet()
myTree = createTree(myDat, labels)
print(myTree)
'''
#串行化决策树
def storeTree(tree, filename):
    """Serialize ``tree`` to the file ``filename`` using pickle.

    An existing file at ``filename`` is overwritten.
    """
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode; the
    # with-block closes the file even if dumping fails.
    with open(filename, 'wb') as f:
        pickle.dump(tree, f)
def getTree(filename):
    """Load and return the object pickled in the file ``filename``."""
    import pickle
    # NOTE: unpickling can execute arbitrary code; only load files that
    # come from a trusted source.
    # pickle reads bytes, so the file must be opened in binary mode; the
    # with-block makes sure the handle is closed.
    with open(filename, 'rb') as f:
        return pickle.load(f)
'''***************************************************************************************'''
'''**********************************分类器***********************************************'''
'''***************************************************************************************'''
def Classify(tree, labels, testList):
    """Classify one sample by walking a nested-dict decision tree.

    ``tree`` has the form ``{feature_name: {feature_value: subtree_or_label}}``.
    ``labels`` lists the feature names and ``testList`` holds the sample's
    feature values in the same order as ``labels``.

    Returns the class label stored at the leaf that is reached, or None
    when the sample's value for a feature has no branch in the tree.
    Raises ValueError if the tree tests a feature that is not in ``labels``.
    """
    # Each tree node is a dict with exactly one key: the feature it tests.
    feature = next(iter(tree))
    secTree = tree[feature]
    indexFeauOfLabels = labels.index(feature)
    feauValue = testList[indexFeauOfLabels]

    if feauValue not in secTree:
        return None  # value never seen on this branch
    branch = secTree[feauValue]
    if isinstance(branch, dict):
        # Inner node: descend into the subtree selected by this value.
        return Classify(branch, labels, testList)
    return branch
# Test (disabled): the bare triple-quoted string below is never executed.
# It reads 'lenses.txt', splits every line on tabs and prints the tree that
# createTree builds for the four named features.
# NOTE(review): lines are split without stripping the line ending, so the last
# field of each row keeps its trailing newline.
'''
f=open('lenses.txt')
dataList=f.readlines()
data=[row.split('\t') for row in dataList]
features= ['age', 'prescript', 'astigmatic', 'tearRate']
print(createTree(data,features))
'''
# 机器学习实战-决策树
# 最新推荐文章于 2022-09-28 01:04:59 发布