#coding:utf-8
from math import log
from collections import Counter
import copy
import collections
# Raw training data: one record per line, five whitespace-separated columns
# in the same order as `features` below. The last column is the class label
# (N or P).
data='''sunny hot high false N
sunny hot high true N
overcast hot high false P
rain mild high false P
rain cool normal false P
rain cool normal true N
overcast cool normal true P
sunny mild high false N
sunny cool normal false P
rain mild normal false P
sunny mild normal true P
overcast mild high true P
overcast hot normal false P
rain mild high true N'''
# Column names for `data`; the final entry ('play') names the label column.
features = ['outlook','temperature','humidity','windy','play']
class BaseClassify:
    """Base class for classifiers working on a list of record dicts.

    Holds a display name, the column names and the name of the target
    (label) column, and knows how to parse a whitespace-separated text
    blob into records.
    """

    def __init__(self, name, data, feature_list):
        """Store the classifier name, initial data and column names.

        name         -- display name of the classifier.
        data         -- initial data (replaced by ``get_data``); may be None.
        feature_list -- column names; the last one is taken as the target.
        """
        self.name = name
        self.data = data
        # The last column is the class label to predict.
        self.target_name = feature_list[-1]
        self.feature_list = feature_list

    def __str__(self):
        return 'name=%12s,label_list=%s' % (self.name, self.feature_list)

    def get_data(self, source_str, num_of_feature, label_list):
        """Parse ``source_str`` into a list of ``{label: value}`` dicts.

        The string is split on any whitespace and consumed
        ``num_of_feature`` tokens at a time; trailing tokens that do not
        fill a whole record are ignored.

        source_str     -- whitespace-separated values, row after row.
        num_of_feature -- number of columns per record.
        label_list     -- column names, one per column.

        Stores a shallow copy of ``label_list`` in ``self.label_list`` and
        independent deep copies of the result in ``self.data_list`` and
        ``self.data``; returns the freshly built list.
        """
        self.label_list = copy.copy(label_list)
        tokens = source_str.split()
        ret = []
        # Floor division keeps the record count an int on Python 2 and 3.
        for i in range(len(tokens) // num_of_feature):
            start = num_of_feature * i
            record = {}
            for offset in range(num_of_feature):
                record[label_list[offset]] = tokens[start + offset]
            ret.append(record)
        self.data_list = copy.deepcopy(ret)
        self.data = copy.deepcopy(ret)
        return ret

    def do_test(self):
        """Hook for subclasses; does nothing here."""
        pass
#
class TreeNode:
    """A node of the decision tree.

    Attributes set by the constructor:
      name         -- label of the node.
      node_list    -- child nodes, initially empty.
      level        -- depth value supplied by the caller.
      data_list    -- records associated with this node.
      feature_list -- column names associated with this node.
      is_leaf      -- False until a caller marks the node as a leaf.
      target_class -- None until a caller assigns a predicted class.
    """

    def __init__(self, name, level, data_list, feature_list):
        self.name = name
        self.node_list = []
        self.level = level
        self.data_list = data_list
        self.feature_list = feature_list
        self.is_leaf = False
        self.target_class = None

    def add_node(self, node):
        """Append ``node`` to this node's children."""
        self.node_list.append(node)

    def __repr__(self):
        # Children are rendered recursively through the list's own str().
        return '%s=%s=%s ' % (self.name, str(self.node_list), str(self.level))

    def __str__(self):
        # Same text as repr(); kept as one implementation to avoid drift.
        return self.__repr__()
class Id3Classify(BaseClassify):
    """ID3-style decision tree builder.

    Records are dicts mapping column name to value; ``self.target_name``
    (set by the base class) names the label column. ``create_tree`` builds
    a tree of ``TreeNode`` objects from ``self.data_list``.
    """

    def do_test(self):
        pass

    def get_distinct_feature_vals(self, feature, data_list):
        """Return the set of distinct values of ``feature`` in ``data_list``."""
        return set(rc[feature] for rc in data_list)

    def get_val_by_feature(self, feature, data_list):
        """Return a Counter of the values of ``feature`` in ``data_list``."""
        return Counter(rc[feature] for rc in data_list)

    def cal_entropy(self, data_list):
        """Return the entropy (natural log) of the target column.

        An empty ``data_list`` yields 0.0.
        """
        total = len(data_list)
        entropy = 0.0
        target_counts = self.get_val_by_feature(self.target_name, data_list)
        for count in target_counts.values():
            prob = 1.0 * count / total
            entropy -= prob * log(prob)
        return entropy

    def filter_by_feature_val(self, feature_val, feature_key, data_list):
        """Return the records whose ``feature_key`` equals ``feature_val``."""
        return [rc for rc in data_list if rc[feature_key] == feature_val]

    def split_feature_func(self, feature, feature_list):
        """Return ``feature_list`` without any entry equal to ``feature``."""
        return [f for f in feature_list if f != feature]

    def split_by_max_gain(self, feature_list, data_list):
        """Pick the feature with the lowest weighted child entropy.

        Minimising the weighted entropy of the subsets is equivalent to
        maximising information gain, because the parent entropy is the
        same for every candidate. The target column is never a candidate.
        Ties keep the earliest feature; if there is no candidate (or no
        data) ``feature_list[0]`` is returned.
        """
        total = len(data_list)
        # Start from +inf rather than a log-based bound so that empty or
        # single-record inputs do not raise or depend on the log base.
        best_entropy = float('inf')
        best_feature = feature_list[0]
        for feature in feature_list:
            if feature == self.target_name:
                continue
            weighted_entropy = 0.0
            value_counts = self.get_val_by_feature(feature, data_list)
            for value, count in value_counts.items():
                subset = self.filter_by_feature_val(value, feature, data_list)
                weighted_entropy += 1.0 * count / total * self.cal_entropy(subset)
            if weighted_entropy < best_entropy:
                best_entropy = weighted_entropy
                best_feature = feature
        return best_feature

    def split_data(self, split_feature, feature, feature_list, data_list):
        """Return the records matching a split value, minus the split column.

        split_feature -- name of the column being split on.
        feature       -- the value of ``split_feature`` to keep (despite the
                         parameter name, this is a value, not a column).
        feature_list  -- columns to copy; ``split_feature`` is skipped.
        data_list     -- records to filter.
        """
        kept_columns = [f for f in feature_list if f != split_feature]
        ret = []
        for rc in data_list:
            if rc[split_feature] != feature:
                continue
            ret.append(dict((f, rc[f]) for f in kept_columns))
        return ret

    def is_all_same_target(self, data_list):
        """Return ``(True, label)`` if every record shares one target value.

        Otherwise (including for empty input) return ``(False, None)``.
        """
        ct_map = self.get_val_by_feature(self.target_name, data_list)
        if len(ct_map) == 1:
            return (True, data_list[0][self.target_name])
        return (False, None)

    def get_max_target(self, data_list):
        """Return the most frequent target label, or None for empty input."""
        ct_map = self.get_val_by_feature(self.target_name, data_list)
        if not ct_map:
            return None
        # most_common(1) -> [(label, count)]; we want the label itself.
        return ct_map.most_common(1)[0][0]

    def bfs_create_tree(self, root, feature_list, data_list, level):
        """Recursively grow the subtree below ``root``.

        When ``feature_list`` has exactly two entries the node becomes a
        leaf labelled with the majority class. Otherwise the best feature
        is chosen, one child is created per observed value, and impure
        children are expanded recursively.
        """
        # NOTE(review): stopping at two entries (one feature plus the
        # target) leaves the last feature unused; confirm this is intended.
        if len(feature_list) == 2:
            root.is_leaf = True
            root.target_class = self.get_max_target(data_list)
            return
        split_feature = self.split_by_max_gain(feature_list, data_list)
        # The children all share the same remaining columns.
        child_feature_list = self.split_feature_func(split_feature, feature_list)
        for key in self.get_val_by_feature(split_feature, data_list):
            child_data_list = self.split_data(split_feature, key, feature_list, data_list)
            is_all_same, target = self.is_all_same_target(child_data_list)
            child_node = TreeNode(str(level) + key + '-' + root.name, level + 1,
                                  child_data_list, child_feature_list)
            child_node.is_leaf = is_all_same
            if is_all_same:
                child_node.target_class = target
            else:
                self.bfs_create_tree(child_node, child_feature_list,
                                     child_data_list, level + 1)
            root.add_node(child_node)
        return

    def create_tree(self):
        """Build the tree from ``self.data_list`` and return its root node."""
        root = TreeNode('root', 0, self.data_list, self.feature_list)
        self.bfs_create_tree(root, root.feature_list, root.data_list, root.level)
        return root

    def _print_node(self, node):
        """Print one node's name/level/leaf flag, features and target class."""
        print('name=[%s],level=[%s], isLeaf=[%s]' % (node.name, str(node.level), str(node.is_leaf)))
        print(node.feature_list)
        print(node.target_class)

    def dfs_visit(self, node):
        """Print the subtree rooted at ``node`` in depth-first pre-order."""
        if node is None:
            return
        self._print_node(node)
        if node.node_list is not None:
            for n in node.node_list:
                self.dfs_visit(n)
        return

    def bfs_visit(self, node):
        """Print the subtree rooted at ``node`` level by level."""
        dq = collections.deque()
        dq.append(node)
        # Loop on the queue itself instead of catching the empty-pop error,
        # so genuine errors in the body are no longer silently swallowed.
        while dq:
            # popleft() gives FIFO order, i.e. an actual breadth-first walk.
            ele = dq.popleft()
            if ele is None:
                continue
            self._print_node(ele)
            for n in ele.node_list:
                dq.append(n)
        return
if __name__ == '__main__':
    # Parse the sample data, build the decision tree and print it.
    classifier = Id3Classify('defaultClassify', None, features)
    classifier.get_data(data, len(features), features)
    tree_root = classifier.create_tree()
    # classifier.dfs_visit(tree_root) prints the same tree via dfs_visit.
    classifier.bfs_visit(tree_root)
# References:
# http://blog.csdn.net/acdreamers/article/details/44661149
# http://blog.csdn.net/zhaoyl03/article/details/8665663