# 决策树剪枝中的损失函数的实现

#!/usr/bin/python
#-*-coding:utf-8 -*-
#决策树的剪枝算法
import ID3alogorithem as id3
from math import log
#  myTree数据类型 {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
#计算每个结点的经验熵
def eachNodeEntropy():
    """Placeholder for computing the empirical entropy of every tree node.

    Not implemented yet: it loads the sample data set, builds its ID3 tree
    through the ``id3`` module, discards the result and always returns 0.
    """
    dataSet, labels = id3.createDataSET()
    # The tree is not used yet; the call is kept so the intended call
    # sequence (including id3.createTree modifying `labels`) stays in place.
    id3.createTree(dataSet, labels)
    # TODO: walk the tree and compute each node's empirical entropy.
    return 0
#得到所有叶节点
# def getNodeList(dataSet):
#     firstDict=dataSet[0]

#找到当前节点的路径
#思考一部分，写一部分，不要想一下子写完
#{0: {'no surfacing': 0}, 1: {'no surfacing': 1, 'flippers': 0}, 2: {'no surfacing': 1, 'flippers': 1}}
#
def getNodePath(dataSet, path=None, i=0, allPaths=None):
    """Enumerate every root-to-leaf path of a decision tree.

    The tree is a nested dict of the form
    ``{feature: {value: subtree_or_class_label, ...}}``, for example
    ``{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}``.

    Args:
        dataSet: the (sub)tree to walk. A non-dict value is treated as a leaf.
        path: feature -> value conditions already fixed above ``dataSet``.
            It is copied, never modified. ``None`` means no conditions.
        i: index assigned to the first leaf found; later leaves get i+1, i+2...
        allPaths: dict the results are added to; a new dict is created when
            ``None``.

    Returns:
        ``allPaths``, mapping consecutive indices (starting at ``i``) to a
        ``{feature: value}`` dict listing, in root-to-leaf order, the
        conditions leading to each leaf. For the example tree above:
        ``{0: {'no surfacing': 0}, 1: {'no surfacing': 1, 'flippers': 0},
        2: {'no surfacing': 1, 'flippers': 1}}``.
    """
    if allPaths is None:
        allPaths = {}
    leafPaths = []
    # Work on a copy so the caller's `path` dict is left untouched.
    _collectLeafPaths(dataSet, dict(path or {}), leafPaths)
    # Number leaves here, in one place, so indices never collide across
    # recursion levels.
    for offset, leafPath in enumerate(leafPaths):
        allPaths[i + offset] = leafPath
    return allPaths


def _collectLeafPaths(tree, prefix, out):
    """Append to `out` the condition dict of every leaf below `tree`."""
    if not isinstance(tree, dict):
        # Reached a class label: `prefix` holds every condition on the way.
        out.append(prefix)
        return
    for feature, branches in tree.items():
        for value, subtree in branches.items():
            # Each branch gets its own copy so siblings keep their ancestors'
            # conditions and never see each other's.
            childPrefix = dict(prefix)
            childPrefix[feature] = value
            _collectLeafPaths(subtree, childPrefix, out)

def getfeatureIndex(labels, feature):
    """Return the position of the first entry of `labels` equal to `feature`.

    Returns None when `feature` does not occur in `labels`.
    """
    for position, name in enumerate(labels):
        if name == feature:
            return position
    return None

# 由路径找到当前的叶节点的样本点
def searchSample(path, dataSet, labels):
    """Return the samples of `dataSet` that satisfy every condition in `path`.

    `path` maps feature names to the value taken on the way to a leaf. For
    each condition, the feature's column is looked up in the current label
    list and the data set is filtered with ``id3.splitDataSet``. The feature
    name is then removed from a new local label list, so the caller's
    `labels` is not modified.
    NOTE(review): removing the label assumes id3.splitDataSet also drops the
    matching column from each sample — confirm against ID3alogorithem.
    """
    remainingLabels = labels
    for featureName, featureValue in list(path.items()):
        column = getfeatureIndex(remainingLabels, featureName)
        dataSet = id3.splitDataSet(dataSet, column, featureValue)
        remainingLabels = list(remainingLabels[:column]) + list(remainingLabels[column + 1:])
    return dataSet

#得到损失函数的值
def getLossMethod(paremeter):
    """Compute, print and return the pruning loss C_alpha(T) = C(T) + alpha*|T|.

    Loads the sample data set and its ID3 tree from the ``id3`` module and
    enumerates the root-to-leaf paths. For each leaf, it selects the samples
    reaching that leaf and adds that leaf's cost, as computed by getCT, to C(T).
    |T| is the number of leaves.

    Args:
        paremeter: alpha, the weight on tree complexity (number of leaves).

    Returns:
        The loss value. Intermediate results are printed along the way.
    """
    dataSet, labels = id3.createDataSET()
    # id3.createTree modifies `labels`, so a fresh copy is loaded below for
    # looking up feature positions.
    allPaths = getNodePath(id3.createTree(dataSet, labels), {}, 0, {})
    dataSet, labels2 = id3.createDataSET()
    print(allPaths)
    CT = 0.0
    for path in allPaths.values():
        print(str(path) + " path")
        # Samples that fall into this leaf.
        nowdata = searchSample(path, dataSet, labels2)
        print(nowdata)
        CT += getCT(nowdata)
    # |T| is the number of leaves, i.e. the number of paths. It is not the
    # largest path index, which is one less.
    method = CT + paremeter * len(allPaths)
    print(str(method) + " methodValues")
    return method

#得到C（T）的值
def getCT(dataSet, base=2):
    """Return one leaf's contribution N_t * H_t(T) to the pruning cost C(T).

    The leaf has N_t samples, N_tk of which belong to class k. Its cost is
    ``N_t * H_t(T) = -sum_k N_tk * log(N_tk / N_t)``. The class label is
    the last element of each sample. The result is non-negative, and an
    empty data set contributes 0.0. The value is also printed.

    Args:
        dataSet: sequence of samples, each indexable with the class label last.
        base: logarithm base (default 2, i.e. entropy measured in bits).

    Returns:
        The leaf's cost as a float.
    """
    from collections import Counter

    numData = len(dataSet)
    # N_tk for every class k present at this leaf, counted in a single pass.
    classCounts = Counter(sample[-1] for sample in dataSet)
    ratio = 0.0
    for numTK in classCounts.values():
        # float() prevents integer division (which would give log(0)) on
        # Python 2; subtracting makes the cost non-negative.
        ratio -= numTK * log(float(numTK) / numData, base)
    print("ratio " + str(ratio))
    return ratio
# Demo run with alpha = 0.3. This executes on import as well as when the
# file is run as a script.
# NOTE(review): consider wrapping it in `if __name__ == "__main__":`.
getLossMethod(paremeter=0.3)