Background
- A decision tree is a classification method that builds a tree by repeatedly selecting the current best feature to split on.
- Shannon entropy measures the disorder (impurity) of a dataset: H(X) = -∑ᵢ pᵢ·log₂(pᵢ), where pᵢ is the probability of class i.
- Information gain: g(D, A) = H(D) - H(D|A), i.e. how much the uncertainty (disorder) of dataset D decreases once feature A is known.
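The two formulas above can be checked numerically on a tiny class distribution. A minimal sketch (the `entropy` helper and the 1-'yes'/3-'no' toy labels are illustrative, not part of the listing below):

```python
from math import log2

def entropy(labels):
    """Shannon entropy H = -sum(p_i * log2(p_i)) over class frequencies."""
    n = len(labels)
    counts = {}
    for c in labels:
        counts[c] = counts.get(c, 0) + 1
    return -sum((k / n) * log2(k / n) for k in counts.values())

# Toy class labels: one 'yes', three 'no'
labels = ['yes', 'no', 'no', 'no']
print(round(entropy(labels), 3))           # H(D) = 0.811

# Conditioning on a binary feature that splits the labels half and half:
left, right = ['yes', 'no'], ['no', 'no']  # feature value 1 vs. value 0
h_cond = 0.5 * entropy(left) + 0.5 * entropy(right)   # H(D|A)
print(round(entropy(labels) - h_cond, 3))  # g(D, A) = 0.311
```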
Idea
- Given a dataset, the algorithm selects the best feature under the current data (i.e. the one with the largest information gain, so that splitting on it reduces the dataset's uncertainty the most), then partitions the dataset by every value of that feature. The chosen feature is removed, and the same procedure is applied recursively to each partition, until either all features have been used or every instance under a branch belongs to the same class.
- If all features have been consumed but the class labels under a branch are still not unique, the branch is classified by majority vote.
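The majority-vote fallback in the last bullet can be sketched with `collections.Counter` (`majority_vote` is an illustrative helper, equivalent in spirit to the `majority` function in the listing below):

```python
from collections import Counter

def majority_vote(class_list):
    """Return the most frequent class label; ties are broken arbitrarily."""
    return Counter(class_list).most_common(1)[0][0]

print(majority_vote(['yes', 'no', 'no']))  # -> 'no'
```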
Implementation
```python
import operator
from math import log

def createData():
    # Toy dataset: two binary features plus a class label
    dataset = [
        [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']
    ]
    labels = ['no surfacing', 'flippers']
    return dataset, labels

def calcShannonEnt(dataset):
    # Shannon entropy of the class labels in the last column
    numEntries = len(dataset)
    labelsCount = {}
    for featureVec in dataset:
        currentLabel = featureVec[-1]
        if currentLabel not in labelsCount:
            labelsCount[currentLabel] = 0
        labelsCount[currentLabel] += 1
    Entropy = 0.0
    for key in labelsCount:
        p = float(labelsCount[key]) / numEntries
        Entropy -= p * log(p, 2)
    return Entropy

def spiltDataset(dataset, axis, value):
    # Rows whose feature `axis` equals `value`, with that feature column removed
    extractData = []
    for featureVec in dataset:
        if featureVec[axis] == value:
            reducedFeatVec = featureVec[:axis]
            reducedFeatVec.extend(featureVec[axis+1:])
            extractData.append(reducedFeatVec)
    return extractData
```
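For illustration, the same split can be written as a single list comprehension (a compact sketch equivalent to `spiltDataset` above, not a replacement for it):

```python
def split_dataset(dataset, axis, value):
    """Rows where column `axis` equals `value`, with that column removed."""
    return [row[:axis] + row[axis+1:] for row in dataset if row[axis] == value]

data = [[1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(split_dataset(data, 0, 1))  # -> [[1, 'yes'], [0, 'no']]
```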
- For the current dataset, compute the information gain of every feature and pick the best one to split on:
```python
def chooseBestFeatureToSpilt(dataset):
    numFeatures = len(dataset[0]) - 1   # last column is the class label
    baseEntropy = calcShannonEnt(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataset]
        uniqueFeature = set(featList)
        newEntropy = 0.0
        for value in uniqueFeature:
            subDataset = spiltDataset(dataset, i, value)
            p = len(subDataset) / float(len(dataset))
            newEntropy += p * calcShannonEnt(subDataset)  # H(D|A)
        infoGain = baseEntropy - newEntropy               # g(D, A) = H(D) - H(D|A)
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
```
```python
def majority(classList):
    # Majority vote: return the most frequent class label
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
```
```python
def createTree(dataset, labels):
    classList = [example[-1] for example in dataset]
    if classList.count(classList[0]) == len(classList):
        return classList[0]          # all instances share one class: leaf node
    if len(dataset[0]) == 1:
        return majority(classList)   # no features left: fall back to majority vote
    bestFeature = chooseBestFeatureToSpilt(dataset)
    bestFeatLabel = labels[bestFeature]
    theTree = {bestFeatLabel: {}}
    del(labels[bestFeature])
    featValues = [example[bestFeature] for example in dataset]
    uniqueFeatValues = set(featValues)
    for value in uniqueFeatValues:
        sublabels = labels[:]        # copy so recursion does not mutate the caller's list
        theTree[bestFeatLabel][value] = createTree(spiltDataset(dataset, bestFeature, value), sublabels)
    return theTree

if __name__ == '__main__':
    datasets, labels = createData()
    tree = createTree(datasets, labels)
    print(tree)
```
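Once built, the nested-dict tree can classify new samples by walking from the root to a leaf. A minimal sketch (`classify` is an assumed helper, not part of the listing above; the hard-coded tree is what the script should print for the toy data):

```python
def classify(tree, featLabels, testVec):
    """Walk the nested-dict tree until a leaf (a class label) is reached."""
    featLabel = next(iter(tree))              # feature name at this node
    featIndex = featLabels.index(featLabel)   # its column in the test vector
    subtree = tree[featLabel][testVec[featIndex]]
    if isinstance(subtree, dict):
        return classify(subtree, featLabels, testVec)
    return subtree

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(classify(tree, ['no surfacing', 'flippers'], [1, 1]))  # -> 'yes'
print(classify(tree, ['no surfacing', 'flippers'], [0, 1]))  # -> 'no'
```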