# 【机器学习】 - 决策树（西瓜数据集）

2 篇文章 0 订阅

#利用决策树算法，对mnist数据集进行测试
import numpy as np

#计算熵
def calcEntropy(dataSet):
mD = len(dataSet)
dataLabelList = [x[-1] for x in dataSet]
dataLabelSet = set(dataLabelList)
ent = 0
for label in dataLabelSet:
mDv = dataLabelList.count(label)
prop = float(mDv) / mD
ent = ent - prop * np.math.log(prop, 2)

return ent

# # 拆分数据集
# # index - 要拆分的特征的下标
# # feature - 要拆分的特征
# # 返回值 - dataSet中index所在特征为feature，且去掉index一列的集合
def splitDataSet(dataSet, index, feature):
splitedDataSet = []
mD = len(dataSet)
for data in dataSet:
if(data[index] == feature):
sliceTmp = data[:index]
sliceTmp.extend(data[index + 1:])
splitedDataSet.append(sliceTmp)
return splitedDataSet

#根据信息增益 - 选择最好的特征
# 返回值 - 最好的特征的下标
def chooseBestFeature(dataSet):
entD = calcEntropy(dataSet)
mD = len(dataSet)
featureNumber = len(dataSet[0]) - 1
maxGain = -100
maxIndex = -1
for i in range(featureNumber):
entDCopy = entD
featureI = [x[i] for x in dataSet]
featureSet = set(featureI)
for feature in featureSet:
splitedDataSet = splitDataSet(dataSet, i, feature)  # 拆分数据集
mDv = len(splitedDataSet)
entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)
if(maxIndex == -1):
maxGain = entDCopy
maxIndex = i
elif(maxGain < entDCopy):
maxGain = entDCopy
maxIndex = i

return maxIndex

# 寻找最多的，作为标签
def mainLabel(labelList):
labelRec = labelList[0]
maxLabelCount = -1
labelSet = set(labelList)
for label in labelSet:
if(labelList.count(label) > maxLabelCount):
maxLabelCount = labelList.count(label)
labelRec = label
return labelRec

#生成树
def createDecisionTree(dataSet, featureNames):
labelList = [x[-1] for x in dataSet]
if(len(dataSet[0]) == 1): #没有可划分的属性了
return mainLabel(labelList)  #选出最多的label作为该数据集的标签
elif(labelList.count(labelList[0]) == len(labelList)): # 全部都属于同一个Label
return labelList[0]

bestFeatureIndex = chooseBestFeature(dataSet)
bestFeatureName = featureNames.pop(bestFeatureIndex)
myTree = {bestFeatureName: {}}
featureList = [x[bestFeatureIndex] for x in dataSet]
featureSet = set(featureList)
for feature in featureSet:
featureNamesNext = featureNames[:]
splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
myTree[bestFeatureName][feature] = createDecisionTree(splitedDataSet, featureNamesNext)
return myTree

#读取西瓜数据集2.0
ifile = open("周志华_西瓜数据集2.txt")
labels = (featureName.split(' ')[0]).split(',')
dataSet = []
for line in lines:
tmp = line.split('\n')[0]
tmp = tmp.split(',')
dataSet.append(tmp)

return dataSet, labels

def main():
#读取数据
print(createDecisionTree(dataSet, featureNames))

if __name__ == "__main__":
main()


{‘纹理’: {‘模糊’: ‘否’, ‘清晰’: {‘根蒂’: {‘稍蜷’: {‘色泽’: {‘乌黑’: {‘触感’: {‘硬滑’: ‘是’, ‘软粘’: ‘否’}}, ‘青绿’: ‘是’}}, ‘蜷缩’: ‘是’, ‘硬挺’: ‘否’}}, ‘稍糊’: {‘触感’: {‘硬滑’: ‘否’, ‘软粘’: ‘是’}}}}

#生成决策树
# featureNamesSet 是featureNames取值的集合
# labelListParent 是父节点的标签列表
def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
labelList = [x[-1] for x in dataSet]
if(len(dataSet) == 0):
return mainLabel(labelListParent)
elif(len(dataSet[0]) == 1): #没有可划分的属性了
return mainLabel(labelList)  #选出最多的label作为该数据集的标签
elif(labelList.count(labelList[0]) == len(labelList)): # 全部都属于同一个Label
return labelList[0]

bestFeatureIndex = chooseBestFeature(dataSet)
bestFeatureName = featureNames.pop(bestFeatureIndex)
myTree = {bestFeatureName: {}}
featureList = featureNamesSet.pop(bestFeatureIndex)
featureSet = set(featureList)
for feature in featureSet:
featureNamesNext = featureNames[:]
featureNamesSetNext = featureNamesSet[:][:]
splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
myTree[bestFeatureName][feature] = createFullDecisionTree(splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
return myTree

#读取西瓜数据集2.0
ifile = open("周志华_西瓜数据集2.txt")
featureNames = (featureName.split(' ')[0]).split(',')
dataSet = []
for line in lines:
tmp = line.split('\n')[0]
tmp = tmp.split(',')
dataSet.append(tmp)

#获取featureNamesSet
featureNamesSet = []
for i in range(len(dataSet[0]) - 1):
col = [x[i] for x in dataSet]
colSet = set(col)
featureNamesSet.append(list(colSet))

return dataSet, featureNames, featureNamesSet


• 12
点赞
• 71
收藏
觉得还不错? 一键收藏
• 打赏
• 8
评论
03-30
07-19 636
09-20 391
07-01 1万+
12-28 716
02-12 1003
10-06 4215
07-14 3745
04-12
05-08 1万+
10-31 1286

### “相关推荐”对你有帮助么？

• 非常没帮助
• 没帮助
• 一般
• 有帮助
• 非常有帮助

ylemfei

¥1 ¥2 ¥4 ¥6 ¥10 ¥20

1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载，可以购买VIP、付费专栏及课程。