ID3 Decision Tree

This post walks through a classic machine-learning algorithm, the ID3 decision tree, with a Python implementation covering everything from data preparation to tree construction. It first defines the basic tree structure and the classification flow, then uses a concrete weather dataset to show how information entropy and information gain drive the choice of split attribute, finally producing a decision tree.
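
For reference, ID3 scores each candidate split attribute A by its information gain, defined from the Shannon entropy of the target labels:

    H(S)       = - Σ_i p_i · log2(p_i)
    Gain(S, A) = H(S) - Σ_v (|S_v| / |S|) · H(S_v)

where p_i is the fraction of records in S carrying the i-th target value, and S_v is the subset of S in which A takes the value v. Because H(S) is identical for every candidate A, the implementation below equivalently chooses the attribute that minimizes the weighted subset entropy Σ_v (|S_v| / |S|) · H(S_v).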


#coding:utf-8
from math import log
from collections import Counter
import copy
import collections

# Raw training data: one record per line (outlook, temperature, humidity, windy, play)
data = '''sunny     hot   high    false  N
sunny     hot   high    true   N
overcast  hot   high    false  P
rain      mild  high    false  P
rain      cool  normal  false  P
rain      cool  normal  true   N
overcast  cool  normal  true   P
sunny     mild  high    false  N
sunny     cool  normal  false  P
rain      mild  normal  false  P
sunny     mild  normal  true   P
overcast  mild  high    true   P
overcast  hot   normal  false  P
rain      mild  high    true   N'''

features = ['outlook', 'temperature', 'humidity', 'windy', 'play']

class BaseClassify:
	def __init__(self, name, data, feature_list):
		self.name = name
		self.data = data
		self.target_name = feature_list[-1]
		self.feature_list = feature_list

	def __str__(self):
		return 'name=%12s, feature_list=%s' % (self.name, self.feature_list)

	def get_data(self, source_str, num_of_feature, label_list):
		# Parse the whitespace-separated table into a list of dicts,
		# one dict per record keyed by feature name.
		self.label_list = copy.copy(label_list)
		data_list = source_str.split()
		ret = []

		for i in range(len(data_list) // num_of_feature):
			start = num_of_feature * i
			data = {}
			for k in range(start, start + num_of_feature):
				data[label_list[k - start]] = data_list[k]
			ret.append(data)
		self.data_list = copy.deepcopy(ret)
		self.data = copy.deepcopy(ret)
		return ret
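	# Example: the first row of the raw data parses to
	# {'outlook': 'sunny', 'temperature': 'hot', 'humidity': 'high',
	#  'windy': 'false', 'play': 'N'}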

	def do_test(self):
		pass

# A single node of the decision tree; leaves carry a predicted target class.
class TreeNode:
	def __init__(self, name, level, data_list, feature_list):
		self.name = name
		self.node_list = []
		self.level = level
		self.data_list = data_list
		self.feature_list = feature_list
		self.is_leaf = False
		self.target_class = None

	def add_node(self, node):
		# print('add_node', node)
		self.node_list.append(node)

	def __str__(self):
		return '%s=%s=%s ' % (self.name, str(self.node_list), str(self.level))

	def __repr__(self):
		return '%s=%s=%s ' % (self.name, str(self.node_list), str(self.level))


class Id3Classify(BaseClassify):
	def do_test(self):
		pass

	def get_distinct_feature_vals(self, feature, data_list):
		ret = set()
		for rc in data_list:
			ret.add(rc[feature])
		return ret

	def get_val_by_feature(self, feature, data_list):
		ret = []
		for rc in data_list:
			ret.append(rc[feature])
		return Counter(ret)

	def cal_entropy(self, data_list):
		# Shannon entropy of the target label distribution. Natural log
		# is used; the base only scales the result and does not change
		# which split has the maximum gain.
		target_name = self.target_name
		target_map = self.get_val_by_feature(target_name, data_list)
		entropy = 0.0
		for key, val in target_map.items():
			prob = 1.0 * val / len(data_list)
			entropy += -prob * log(prob)
		return entropy
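	# Example: on the full 14-row dataset (9 P, 5 N) cal_entropy returns
	# -(9/14)*ln(9/14) - (5/14)*ln(5/14) ≈ 0.652 nats, i.e. the classic
	# 0.940 bits when expressed in base 2.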

	def filter_by_feature_val(self, feature_val, feature_key, data_list):
		ret = []
		for data in data_list:
			if data[feature_key] == feature_val:
				ret.append(data)
		return ret

	def split_feature_func(self, feature, feature_list):
		ret =[]
		for f in feature_list:
			if feature == f:
				continue
			ret.append(f)
		return ret

	def split_by_max_gain(self, feature_list, data_list):
		# Pick the feature with the maximum information gain. Since the
		# parent entropy H(S) is the same for every candidate, it is
		# enough to minimize the weighted entropy of the subsets.
		target_name = self.target_name

		min_sub_entropy = float('inf')
		min_feature = feature_list[0]
		for feature in feature_list:
			if feature == target_name:
				continue
			sub_entropy = 0.0
			feature_counter = self.get_val_by_feature(feature, data_list)
			for key, val in feature_counter.items():
				sub_entropy += 1.0 * val / len(data_list) * self.cal_entropy(self.filter_by_feature_val(key, feature, data_list))
			if sub_entropy < min_sub_entropy:
				min_sub_entropy = sub_entropy
				min_feature = feature
		return min_feature
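	# On the sample weather data this selects outlook at the root: the
	# classic gains are about 0.246 bits for outlook versus 0.151 for
	# humidity, 0.048 for windy and 0.029 for temperature.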

	def split_data(self, split_feature, feature_val, feature_list, data_list):
		# Keep only records where split_feature == feature_val, dropping
		# the split_feature column from each surviving record.
		ret = []
		for rc in data_list:
			if rc[split_feature] != feature_val:
				continue
			cp = dict()
			for f in feature_list:
				if f == split_feature:
					continue
				cp[f] = rc[f]
			ret.append(cp)
		return ret

	def is_all_same_target(self, data_list):
		# True if every record shares the same target value (pure node).
		ct_map = self.get_val_by_feature(self.target_name, data_list)
		if len(ct_map) == 1:
			return (True, data_list[0][self.target_name])
		return (False, None)

	def get_max_target(self, data_list):
		# Majority vote: return the most common target value.
		ct_map = self.get_val_by_feature(self.target_name, data_list)
		return ct_map.most_common(1)[0][0]

	def bfs_create_tree(self, root, feature_list, data_list, level):
		# Recursively grow the tree. Stop splitting once only one
		# non-target feature remains, falling back to a majority-vote leaf.
		if len(feature_list) == 2:
			root.is_leaf = True
			root.target_class = self.get_max_target(data_list)
			return

		split_feature = self.split_by_max_gain(feature_list, data_list)
		feature_counter = self.get_val_by_feature(split_feature, data_list)
		for key, val in feature_counter.items():
			child_feature_list = self.split_feature_func(split_feature, feature_list)
			child_data_list = self.split_data(split_feature, key, feature_list,data_list)
			is_all_same,target =self.is_all_same_target(child_data_list)
			
			child_node = TreeNode(str(level) + key + '-' + root.name, level + 1, child_data_list, child_feature_list)

			child_node.is_leaf = is_all_same
			if is_all_same:
				child_node.target_class = target
			else:
				self.bfs_create_tree(child_node, child_feature_list, child_data_list, level + 1)
			root.add_node(child_node)
		return

	def create_tree(self):
		root = TreeNode('root', 0, self.data_list, self.feature_list)
		self.bfs_create_tree(root, root.feature_list, root.data_list, root.level)
		return root

	def dfs_visit(self, node):
		# Pre-order depth-first traversal: print the node, then recurse
		# into its children.
		if node is None:
			return

		print('name=[%s], level=[%s], isLeaf=[%s]' % (node.name, str(node.level), str(node.is_leaf)))
		# print(node.data_list)
		print(node.feature_list)
		print(node.target_class)

		for n in node.node_list:
			self.dfs_visit(n)
		return

	def bfs_visit(self, node):
		# Level-order (breadth-first) traversal using a FIFO queue;
		# popleft() makes this a true BFS rather than a stack-based DFS.
		dq = collections.deque()
		dq.append(node)

		while dq:
			ele = dq.popleft()
			print('name=[%s], level=[%s], isLeaf=[%s]' % (ele.name, str(ele.level), str(ele.is_leaf)))
			# print(ele.data_list)
			print(ele.feature_list)
			print(ele.target_class)
			for n in ele.node_list:
				dq.append(n)
		return


if __name__ == '__main__':
	classify = Id3Classify('defaultClassify', None, features)
	classify.get_data(data, 5, features)
	root = classify.create_tree()
	# classify.dfs_visit(root)
	classify.bfs_visit(root)
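
On the sample weather data the root should split on outlook: the overcast branch is a pure P leaf, the sunny branch splits further on humidity (high → N, normal → P), and the rain branch on windy (true → N, false → P), matching the textbook result for this dataset.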

References:

http://blog.csdn.net/acdreamers/article/details/44661149

http://blog.csdn.net/zhaoyl03/article/details/8665663

