import math
from matplotlib.font_manager import FontProperties
import numpy as np
import matplotlib.pyplot as plt
import operator
import pickle
def createDataSet():
    """Build the small toy dataset used by the demo.

    Returns:
        A ``(dataSet, labels)`` tuple. ``dataSet`` is a fresh list of 15 rows;
        each row holds four integer-coded feature values followed by a class
        label (``'yes'`` or ``'no'``). ``labels`` names the four feature
        columns in order.
    """
    # Rows are kept as tuples here and converted to fresh lists on every call,
    # so callers can mutate the result without affecting later calls.
    raw_rows = (
        (0, 0, 0, 0, 'no'),
        (0, 0, 0, 1, 'no'),
        (0, 1, 0, 1, 'yes'),
        (0, 1, 1, 0, 'yes'),
        (0, 0, 0, 0, 'no'),
        (1, 0, 0, 0, 'no'),
        (1, 0, 0, 1, 'no'),
        (1, 1, 1, 1, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (2, 0, 1, 2, 'yes'),
        (2, 0, 1, 1, 'yes'),
        (2, 1, 0, 1, 'yes'),
        (2, 1, 0, 2, 'yes'),
        (2, 0, 0, 0, 'no'),
    )
    feature_names = ['F1-AGE', 'F2-WORK', 'F3-HOME', 'F4-LOAN']
    rows = [list(raw) for raw in raw_rows]
    return rows, feature_names
def createTree(dataset, labels, featLabels=None):
    """Recursively build a decision tree as nested dicts.

    At each node the split feature is the one returned by
    ``chooseBestFeatureToSplit``. Used feature columns are removed from the
    child subsets by ``splitDataSet``.

    Args:
        dataset: Non-empty list of rows. Every element except the last is a
            feature value; the last element is the class label.
        labels: Feature names, aligned with the feature columns of
            ``dataset``. This sequence is NOT modified.
        featLabels: Optional list. The name of each chosen split feature is
            appended to it, in the order the splits are made (parent before
            children). Defaults to a new, discarded list.

    Returns:
        Either a class label (leaf), or a dict
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    if featLabels is None:
        featLabels = []
    classList = [item[-1] for item in dataset]
    # Pure node: every remaining example carries the same class label.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left, so there is nothing more to split on.
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataset)
    # A negative index means no feature gave a positive information gain.
    # Splitting on -1 would treat the class column as a feature, so emit a
    # majority-vote leaf instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    # Build the children's label list by slicing rather than `del`, so the
    # caller's `labels` is left untouched.
    subLabels = list(labels[:bestFeat]) + list(labels[bestFeat + 1:])
    myTree = {bestFeatLabel: {}}
    for value in set(item[bestFeat] for item in dataset):
        subset = splitDataSet(dataset, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(subset, subLabels, featLabels)
    return myTree
def majorityCnt(dataset):
    """Return the most frequent class label.

    Args:
        dataset: Iterable of class labels (despite the name, these are labels,
            not full rows). Labels may be any hashable values, not just
            ``'yes'``/``'no'``.

    Returns:
        The label with the highest count. Ties go to the label encountered
        first.

    Raises:
        ValueError: If ``dataset`` is empty, because there is no label to
            return.
    """
    from collections import Counter

    counts = Counter(dataset)
    if not counts:
        raise ValueError("majorityCnt() requires at least one class label")
    # most_common orders equal counts by first occurrence, so ties are
    # deterministic.
    return counts.most_common(1)[0][0]
def chooseBestFeatureToSplit(dataset):
    """Pick the feature column whose split gives the largest information gain.

    Information gain is the entropy of ``dataset`` minus the size-weighted
    entropy of the subsets produced by splitting on each distinct value of a
    column. The last column is the class label and is never considered.

    Args:
        dataset: Non-empty list of rows whose last element is the class label.

    Returns:
        Index of the best feature column. Returns -1 if no column yields a
        strictly positive gain. Ties keep the lowest index.
    """
    feature_count = len(dataset[0]) - 1
    row_total = float(len(dataset))
    parent_entropy = calcShannonEnt(dataset)
    best_gain, best_index = 0, -1
    for col in range(feature_count):
        distinct = set(row[col] for row in dataset)
        weighted = 0
        for v in distinct:
            part = splitDataSet(dataset, col, v)
            weighted += len(part) / row_total * calcShannonEnt(part)
        gain = parent_entropy - weighted
        # Strict '>' so only a real improvement replaces the current best.
        if gain > best_gain:
            best_gain, best_index = gain, col
    return best_index
def splitDataSet(dataset, axis, val):
    """Select rows whose column ``axis`` equals ``val`` and drop that column.

    Args:
        dataset: List of rows (sequences).
        axis: Index of the column to match on and remove.
        val: Value the column must equal for a row to be kept.

    Returns:
        A new list of new rows. Each row is the original without element
        ``axis``. The input rows are not modified.
    """
    return [row[:axis] + row[axis + 1:] for row in dataset if row[axis] == val]
def calcShannonEnt(dataset):
    """Return the Shannon entropy, in bits, of the class labels in ``dataset``.

    The class label is the last element of each row. An empty dataset yields
    ``0``.
    """
    total = len(dataset)
    tally = {}
    for row in dataset:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0
    for count in tally.values():
        p = count / float(total)
        entropy += -p * math.log(p, 2)
    return entropy
if __name__ == '__main__':
    # Demo: build the tree for the toy dataset, then print it along with the
    # split-feature names in the order they were chosen.
    rows, feature_names = createDataSet()
    chosen_features = []
    tree = createTree(rows, feature_names, chosen_features)
    print(tree)
    print(chosen_features)
# 机器学习：一棵简单的决策树代码实现
# 于 2023-07-01 16:00:48 首次发布