机器学习之决策树,ID3算法

本文介绍了决策树构建过程,详细解析了ID3算法中的信息增益计算方法,并通过实例展示了如何使用Python实现决策树的构建及分类。

首先来看下决策树的基本定义,维基百科:

http://zh.wikipedia.org/wiki/%E5%86%B3%E7%AD%96%E6%A0%91

为了构建决策树,首先需要决定数据集中那个属性来划分数据集,去构建决策树。

分类能力最好的属性被选作树的根节点,


这里主要用于学习布尔函数的ID3算法。

首先介绍下ID3算法的概要


看上去有点看不懂,先放这吧,下面仔细来论述下

为了选择最好的划分,这里引入信息增益的概念,

信息增益:用来衡量给定的属性区分训练样例的能力

ID3算法在增长树的每一步使用这个增益标准从候选属性中选择属性

熵:(entropy)刻画了任意样例的纯度,给定关于某个目标概念的正反样例集S则entropy的计算公式是:

E(S)=-P(+)logP(+) - P(-)logP(-)

举一个例子

一个集合中包含[9+,-5],则E(S)=-(9/14)log(9/14) - (5/14)log(5/14)


信息增益(information gain)Gain(S,A) = E(S) - sum((|Sv|/|S|) *E(Sv))


举一个例子,还是拿上面的说,加入14个样本中假定正例中有6个是YES属性,反例中有2个是YES

Gain(S,YES) = E(S) - (8/14)E(S=YES) - (6/14)E(S=NO)

def calShanonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet: #the the number of unique elements and their occurance
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2) #log base 2
    return shannonEnt

python代码如上

其中

dataset类似这样[1, 1, 'yes']



featVec[-1] = yes or no

首先计算所有数据集中出现的yes和no的次数,然后计算熵

熵越大,数据越混乱

现在来创建一个数据集来演示程序怎么使用

def createDataSet():
	dataset =  [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
	labels = ['no surfacing','flippers']
	return dataset,labels



现在计算得到了熵,下一步就是要分割数据集获得信息增益

分割数据

def splitDataSet(dataset,axis,value):
	retDataSet = []
	for featVec in dataset:
		if featVec[axis] == value:
			reducedFeatVec = featVec[:axis]
			reducedFeatVec.extend(featVec[axis+1:])
			retDataSet.append(reducedFeatVec)
		
	return retDataSet


这个函数的作用是按照索引axis上的值value来分割数据,返回的是axis上值为value的数据集,不包括axis这一维

下一步开始选择最优属性来分割了

def chooseBestFeatureToSplit(dataset):
	numFeatures = len(dataset[0]) - 1
	baseEntropy = calShanonEnt(dataset)
	bestInfoGain = 0.0
	bestFeature = -1

	for i in range(numFeatures):
		featList = [example[i] for example in dataset]
		 
		uniqueVals = set(featList)
		 
		newEntropy = 0
		for value in uniqueVals:
			subDataSet = splitDataSet(dataset,i,value)
			prob = len(subDataSet)/float(len(dataset))
			newEntropy += prob * calShanonEnt(subDataSet)
		infoGain = baseEntropy - newEntropy
		if infoGain > bestInfoGain:
			bestInfoGain = infoGain
			bestFeature = i
			pass
			pass
		pass
	return bestFeature
	pass

这一步是计算信息增益

下一步就是递归构建树了

def  createTree(dataset,labels):
	classList = [example[-1] for example in dataset]
	if classList.count(classList[0]) == len(classList):
		return classList[0]
	if len(dataset[0])==1:
		return majorityCnt(classList)
	bestFeat = chooseBestFeatureToSplit(dataset)
	bestFeatLabel = labels[bestFeat]
	myTree = {bestFeatLabel:{}}
	del(labels[bestFeat])
	featValues = [example[bestFeat] for example in dataset]
	uniqueVals = set(featValues)
	for value in uniqueVals:
		sublabels = labels[:]
		myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset,bestFeat,value),sublabels)
		pass
	return myTree
	pass

这个函数有两个输入,一个是数据集,一个是类别

递归停止条件:

1.所有的类别是相同的

2.无属性可分了

针对第二种情况需要单独处理

def majorityCnt(classList):
	classCount = {}
	for vote  in classList:
		if vote not in classCount.keys():
			classCount[vote] = 0
			classCount[vote] += 1

			 
	sortedClassCount = sorted(classCount.iteritems,key=operator.itemgetter(1),reverse=True)
		 
	return sortedClassCount[0][0]
	pass

这里暂时没有看懂。。。。。。。。


这里创建了一个字典来存放树,bestFeat用来存放每一次选择的最优属性的索引,value存放值,然后来分割数据集,最后创建树。


得到树了就可以用来分类了。

def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict): 
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel

就是拿新的数据沿着树走一遍,走到叶子节点的时候得到的分类就是结果


测试:

def test():
	myDat,labels = createDataSet()
	#print calShanonEnt(myDat)
	#print splitDataSet(myDat,0,1)
	#print chooseBestFeatureToSplit(myDat)
	trees = createTree(myDat,labels)
	labels = ['no surfacing','flippers']
	print classify(trees,labels,[1,0])

未完待续。。。。。下一步需要将树画出来


评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值