# main.py
import numpy as np
import pickle
import os
import treePlotter
def CreateTrainingDataset():
    """Return the toy 'buys computer' training set and its attribute names.

    Each row is [age, income, student, credit_rating, label], where the
    first four entries are small integer codes and the label is 'Y'/'N'.
    Rows are returned as mutable lists because the tree builder rewrites
    them in place while splitting.
    """
    rows = (
        (0, 2, 0, 0, 'N'),
        (0, 2, 0, 1, 'N'),
        (1, 2, 0, 0, 'Y'),
        (2, 1, 0, 0, 'Y'),
        (2, 0, 1, 0, 'Y'),
        (2, 0, 1, 1, 'N'),
        (1, 0, 1, 1, 'Y'),
        (0, 1, 0, 0, 'N'),
        (0, 0, 1, 0, 'Y'),
        (2, 1, 1, 0, 'Y'),
        (0, 1, 1, 1, 'Y'),
        (1, 1, 0, 1, 'Y'),
        (1, 2, 1, 0, 'Y'),
        (2, 1, 0, 1, 'N'),
    )
    attribute_names = ["age", "income", "student", "credit_rating"]
    return [list(row) for row in rows], attribute_names
def CreateTestDataset():
    """Return unlabeled test rows and the matching attribute names.

    Rows share the [age, income, student, credit_rating] encoding of the
    training set but carry no class label.
    """
    rows = (
        (0, 1, 0, 0),
        (0, 2, 1, 0),
        (2, 1, 1, 0),
        (0, 1, 1, 1),
        (1, 1, 0, 1),
        (1, 0, 1, 0),
        (2, 1, 0, 1),
    )
    attribute_names = ["age", "income", "student", "credit_rating"]
    return [list(row) for row in rows], attribute_names
def GetClassInfo(Dataset):
    """Count class labels (last column of each row) in Dataset.

    Returns a dict mapping label -> count, with entries ordered by
    descending count so that CalMostClass can simply take the first key.
    """
    counts = {}
    for item in Dataset:
        label = item[-1]
        # dict.get replaces the original redundant membership-test branches.
        counts[label] = counts.get(label, 0) + 1
    # Most frequent class first; ties keep first-seen order (sorted is stable).
    return dict(sorted(counts.items(), key=lambda kv: kv[1], reverse=True))
def CalMostClass(classInfo):
    """Return the majority class label.

    Assumes classInfo is the frequency-sorted dict produced by
    GetClassInfo, whose first key is the most common label.
    """
    return next(iter(classInfo))
def ComputeEntropy(Dataset):
    """Return the Shannon entropy (base 2) of the class labels in Dataset.

    The class label is the last element of each row. An empty dataset
    yields zero entropy.
    """
    # Tally label frequencies inline (same counts GetClassInfo would give;
    # entropy does not depend on their order).
    counts = {}
    for row in Dataset:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    total = sum(counts.values())
    entropy = 0.0
    for cnt in counts.values():
        p = cnt / total
        entropy -= p * np.log2(p)
    return entropy
def computeAttrGainNPartition(Dataset, attributeIndex):
    """Compute the information gain of splitting Dataset on one attribute.

    Parameters:
        Dataset: list of rows, each ending with the class label.
        attributeIndex: column index of the attribute to split on.

    Returns:
        (gain, partition) where partition maps each attribute value to the
        list of rows carrying that value (rows are shared, not copied).
    """
    total = len(Dataset)
    # Group rows by attribute value; setdefault replaces the original
    # duplicated if/else append branches.
    partition = {}
    for row in Dataset:
        partition.setdefault(row[attributeIndex], []).append(row)
    # Gain = H(Dataset) - sum_v |D_v|/|D| * H(D_v).
    # (The original also accumulated an unused `amount` total — removed.)
    gain = ComputeEntropy(Dataset)
    for subset in partition.values():
        gain -= (len(subset) / total) * ComputeEntropy(subset)
    return gain, partition
def CreateDecisionTree(Dataset, attributeList):
    """Recursively build an ID3 decision tree.

    Parameters:
        Dataset: list of rows [attr0, ..., attrK, label].
        attributeList: names of the remaining attribute columns.
            NOTE: consumed in place — the chosen attribute is deleted at
            each level; pass a copy if the caller needs the original.

    Returns:
        Either a class label (leaf) or a nested dict of the form
        {attrName: {attrValue: subtree_or_label}}.
    """
    classInfo = GetClassInfo(Dataset)
    n_rows = len(Dataset)
    # Stop 1: no attributes left to split on -> majority vote.
    if len(attributeList) == 0:
        return CalMostClass(classInfo)
    # Stop 2: all rows share a single class -> return that class.
    # (The original had an unreachable `break` after this return — removed.)
    for label, count in classInfo.items():
        if count == n_rows:
            return label
    # Stop 3: every row has identical attribute values -> majority vote.
    first_attrs = Dataset[0][:-1]
    if all(row[:-1] == first_attrs for row in Dataset):
        return CalMostClass(classInfo)
    # Choose the attribute with the highest information gain.
    theBestAttrIndex = 0
    theBestAttrGain = 0
    theBestAttrPartition = {}
    for attributeIndex in range(len(attributeList)):
        gain, attributePartition = computeAttrGainNPartition(Dataset, attributeIndex)
        if gain > theBestAttrGain:
            theBestAttrGain = gain
            theBestAttrIndex = attributeIndex
            theBestAttrPartition = attributePartition
    attrName = attributeList[theBestAttrIndex]
    # In-place deletion: recursion below sees the shrunken list (copied per
    # branch), matching the shrunken rows produced next.
    del attributeList[theBestAttrIndex]
    # Drop the chosen attribute's column from every row of every subset.
    for valList in theBestAttrPartition.values():
        for index in range(len(valList)):
            valList[index] = (valList[index][:theBestAttrIndex]
                              + valList[index][theBestAttrIndex + 1:])
    Tree = {attrName: {}}
    for keyAttrVal, valDataset in theBestAttrPartition.items():
        # Each branch gets its own copy of the remaining attribute names.
        Tree[attrName][keyAttrVal] = CreateDecisionTree(valDataset, attributeList[:])
    return Tree
def Predict(DataSet, testArrtList, decisionTree):
    """Classify each row of DataSet with a trained decision tree.

    Parameters:
        DataSet: list of unlabeled rows, columns ordered as testArrtList.
        testArrtList: attribute names matching the rows' column order.
        decisionTree: nested dict {attr: {value: subtree_or_label}} as
            produced by CreateDecisionTree, or a bare label for a
            single-leaf tree.

    Returns:
        List of predicted class labels, one per row.

    Raises:
        KeyError: if a row contains an attribute value never seen during
            training at the corresponding tree node.
    """
    predicted_label = []
    for dataItem in DataSet:
        node = decisionTree
        # Walk down until we hit a leaf label. (The original also tested
        # for `set` trees — a type this pipeline never produces — removed.)
        while isinstance(node, dict):
            attr = list(node.keys())[0]
            if attr not in testArrtList:
                # Malformed tree / attribute list: surface the unknown
                # attribute name instead of looping forever.
                node = attr
                break
            value = dataItem[testArrtList.index(attr)]
            node = node[attr][value]
        predicted_label.append(node)
    return predicted_label
def SaveModel(decisionTree, filename):
    """Pickle decisionTree to filename (binary mode).

    Uses a context manager so the file handle is flushed and closed even
    on error — the original leaked the open handle.
    """
    with open(filename, 'wb') as f:
        pickle.dump(decisionTree, f)
def LoadModel(filename):
    """Unpickle and return a decision tree saved by SaveModel.

    Uses a context manager so the handle is closed — the original leaked
    it. NOTE: pickle is unsafe on untrusted files; only load models this
    program wrote itself.
    """
    with open(filename, 'rb') as f:
        return pickle.load(f)
if __name__ == '__main__':
    # Demo driver: train on the toy dataset, round-trip the model through
    # disk, predict on the test rows, and plot the tree.
    # Resolve the model path next to this script so the cwd doesn't matter.
    base = os.path.dirname(os.path.abspath(__file__))
    trainingDataset, attributeList = CreateTrainingDataset()
    testDataset, testArrtList = CreateTestDataset()
    # os.path.join replaces manual "/" concatenation (portable separators).
    path = os.path.join(base, "DecisionTreeModel.txt")
    print(path)
    # NOTE: CreateDecisionTree consumes attributeList in place; Predict
    # below therefore uses the untouched testArrtList copy.
    decisionTree = CreateDecisionTree(trainingDataset, attributeList)
    SaveModel(decisionTree, path)
    model = LoadModel(path)
    print(model)
    result = Predict(testDataset, testArrtList, model)
    print(result)
    treePlotter.createPlot(model)
# treePlotter.py
import matplotlib.pyplot as plt
# Matplotlib box style for internal (decision) nodes: sawtooth outline,
# light-grey fill.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# Box style for leaf (class label) nodes: rounded outline, light-grey fill.
leafNode = dict(boxstyle="round4", fc="0.8")
# Arrow style used for the edges drawn by plotNode's annotate call.
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node at centerPt with an edge back to parentPt.

    nodeType is one of the module-level box styles (decisionNode/leafNode).
    Coordinates are in axes-fraction units (0..1). Relies on
    createPlot.ax1 having been set by createPlot() before any call.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
        xytext=centerPt, textcoords='axes fraction', \
        va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    A child that is itself a dict is an internal node and is counted
    recursively; any other child value is a single leaf label.
    """
    root = list(myTree.keys())[0]
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in myTree[root].values()
    )
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    A branch ending directly in a leaf has depth 1; each nested dict adds
    one level. The overall depth is the maximum over all branches.
    """
    root = list(myTree.keys())[0]
    deepest = 0
    for child in myTree[root].values():
        branch = getTreeDepth(child) + 1 if isinstance(child, dict) else 1
        deepest = max(deepest, branch)
    return deepest
def plotMidText(cntrPt, parentPt, txtString):
    """Label the midpoint of the edge between cntrPt and parentPt.

    Used to print the attribute value on each branch. Requires
    createPlot.ax1 to be initialized by createPlot().
    """
    mid_x = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    mid_y = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(mid_x, mid_y, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree below parentPt, labeling the edge nodeTxt.

    Layout state lives in function attributes seeded by createPlot():
    plotTree.totalw / plotTree.totalD hold the whole tree's leaf count and
    depth, while plotTree.xOff / plotTree.yOff act as a cursor for the
    next leaf slot and the current level.
    """
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)  # NOTE(review): computed but never used
    firstStr = list(myTree.keys())[0]
    # Center this subtree's root above the horizontal span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # Internal node: recurse with this node as the new parent.
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance the x cursor one slot, draw the label and edge.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the level on the way back up the recursion.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Render the decision tree inTree in a matplotlib window.

    Seeds the function-attribute layout state consumed by plotTree() and
    exposes the axes as createPlot.ax1 for plotNode()/plotMidText().
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # Shared axes handle used by plotNode and plotMidText.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalw = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start the leaf cursor half a slot left of the first leaf position.
    plotTree.xOff = -0.5 / plotTree.totalw
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()