# 决策树

12人阅读 评论(0)

def calculateEntropy(dataSet):
numberEntries = len(dataSet)
labelCounts = {}
for featVector in dataSet:
currentLabel = featVector[-1]  # 取每行数据的类别 --> 最后一列
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numberEntries  # 计算每个类别出现的概率
shannonEnt -= prob * log(prob,2)              # 计算香农熵
return shannonEnt

def createDataSet():
dataSet = [
[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']
]
labels = ['no surfacing', 'flippers']
return  dataSet, labels

def splitDataSet(dataSet,axis,value): # dataSet给定数据集 axis给定特征(索引)  axis给定特征的值
returnDataSet = []
for featureVec in dataSet:
if featureVec[axis] == value:
reducedFeatVec = featureVec[:axis]
reducedFeatVec.extend(featureVec[axis + 1:])
returnDataSet.append(reducedFeatVec)
return returnDataSet

def chooseBestFeatureToSplit(dataSet):
numberFeatures = len(dataSet[0]) - 1
baseEntropy = calculateEntropy(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numberFeatures):  # 按照第i个特征进行划分的情况
# 使用列表推导式创建新的列表  保存第i个特征所有的特征值
featureList = [example[i] for example in dataSet]
uniqueVals = set(featureList) # 去除重复值
newEntropy = 0.0
for value in uniqueVals:  #第i个特征所有的特征值
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/len(dataSet)
newEntropy += prob * calculateEntropy(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature

def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
return sortedClassCount[0][0]

# 创建决策树
def createTree(dataSet,labels):  # labels 保存所有的特征
classList = [example[-1] for example in dataSet]  # 保存所有类别数据
if classList.count(classList[0]) == len(classList):
return classList[0]             # 类别完全相同则停止划分
if len(dataSet[0]) == 1:
return  majorityCnt(classList)  # 遍历完所有特征时返回出现次数最多的类别
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])   # 将已经用过的特征从labels中删除
featValues = [example[bestFeat] for example in dataSet]  # 保存此次用于分类的特征的所有值
uniqueVals = set(featValues)  # 去除重复值
for value in uniqueVals:      # 遍历所有类别
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
return  myTree

0
0

* 以上用户言论只代表其个人观点，不代表CSDN网站的观点或立场
个人资料
• 访问：44次
• 积分：40
• 等级：
• 排名：千里之外
• 原创：4篇
• 转载：0篇
• 译文：0篇
• 评论：0条
文章分类
文章存档