Decision Tree
I recently worked through Chapter 4 (Decision Trees) of Machine Learning by Zhou Zhihua, so here are a few short notes.
Basic Concepts
A decision tree is, as the name suggests, a tree used for making decisions; in effect it is just a classifier. The basic idea behind growing one is simple: greedily split the data so that the class impurity of each subset keeps dropping, or equivalently so that the purity keeps rising.
There are many ways to measure impurity (or purity). One of the most classic and simplest is information entropy,
$$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k$$
where p_k is the proportion of class-k samples in the set D. Anyone who has taken an information theory or communications course knows entropy well: the larger the entropy, the greater the uncertainty, i.e. the more mixed the set. What a decision tree does is reduce this uncertainty, which leads to the definition of information gain,
$$\mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$
where D is the data set, a is a discrete attribute with V possible values, and D^v is the subset of samples in D whose value on attribute a is a^v. At each node, the attribute with the largest information gain is chosen as the split.
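Before the full implementation, here is a minimal standalone sketch (not part of the exercise code below) of how these two quantities are computed. The 8-positive / 9-negative class counts are those of watermelon data set 3.0 used later; the split counts in the example are made up purely for illustration.

from math import log2

def entropy(pos, neg):
    """Binary entropy of a set with pos positive and neg negative samples."""
    ent = 0.0
    for count in (pos, neg):
        p = count / (pos + neg)
        if p > 0:
            ent -= p * log2(p)
    return ent

# Ent(D) for the full data set below: 8 positive and 9 negative samples.
ent_d = entropy(8, 9)

# A made-up binary split of D, just to show how Gain(D, a) is assembled.
branches = [(6, 2), (2, 7)]            # (pos, neg) counts per branch
gain = ent_d - sum((p + n) / 17.0 * entropy(p, n) for p, n in branches)
print("Ent(D) = %.3f, Gain = %.3f" % (ent_d, gain))   # Ent(D) = 0.998, Gain = 0.211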
Implementation
Below is my implementation of Exercise 4.3 (generating a decision tree for watermelon data set 3.0 with information gain), including the handling of continuous attributes. See my GitHub for details.
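The continuous attributes (density and sugar content) are handled with the usual bi-partition trick from the book, which is what preprocess.py and find_best below implement: sort the n observed values of attribute a, take the midpoints of adjacent values as candidate thresholds, and keep the threshold with the largest gain,

$$T_a = \left\{ \frac{a^i + a^{i+1}}{2} \;\middle|\; 1 \le i \le n-1 \right\}, \qquad \mathrm{Gain}(D, a) = \max_{t \in T_a} \left[ \mathrm{Ent}(D) - \sum_{\lambda \in \{-,+\}} \frac{|D_t^{\lambda}|}{|D|}\,\mathrm{Ent}(D_t^{\lambda}) \right]$$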
decision_tree.py
from aenum import Enum
from copy import deepcopy
import numpy as np
class Attribute(Enum):
COLOR = 0
ROOT = 1
SOUND = 2
TEXTURE = 3
NAVEL = 4
TOUCH = 5
DENSITY = 6
SUGAR = 7
class Node(object):
"""
Tree node class.
"""
def __init__(self):
self.attri = None # Internal node has attribute.
self.isLeaf = False # Is leaf node or not ?
self.decision = False # Final decision: T or F.
self.parent = None
self.children = [] # A list store children.
def set_leaf(self, positive):
"""
:param positive: bool
:return:
"""
self.isLeaf = True
self.decision = positive
class DataB(object):
"""
Data management base class.
"""
def __init__(self, data_set, y, idx):
"""
:param data_set: List(List(int))
:param y: List(int)
:param idx: List(int)
"""
self.data = data_set
self.y = y
self.idx = idx
    def filter(self, ai, val, midpoint=None):
        """
        Return a new data view containing only samples whose attribute ai equals val.
        midpoint is accepted for interface compatibility with subclasses that
        handle continuous attributes; it is ignored here.
        :param ai: int
        :param val: int
        :param midpoint: float or None
        :return: DataB()
        """
        filter_idx = [x for x in self.idx if self.data[x][ai] == val]
        return DataB(self.data, self.y, filter_idx)
def empty(self):
if len(self.idx) == 0:
return True
return False
def is_positive(self):
if len(self.idx) == sum([self.y[i] for i in self.idx]):
return True
return False
def is_negative(self):
if sum([self.y[i] for i in self.idx]) == 0:
return True
return False
def mark_most(self):
if self.empty():
return None
num_pos = sum([self.y[i] for i in self.idx])
num_neg = len(self.idx) - num_pos
print("num_pos: %d, num_neg: %d" %(num_pos, num_neg))
if num_pos >= num_neg:
return True
return False
class Decisiontree(object):
"""
Decision Tree base class
"""
def __init__(self, attri_list):
self.attri_list = attri_list
    def find_best(self, data, attri_set):
        # Placeholder criterion: pick an arbitrary attribute and no midpoint.
        # Subclasses (e.g. DecisionTreeInfoGain) override this method.
        best = list(attri_set)[0]
        return best, None
def tree_gen(self, data, attri_set):
"""
Recursive function use to generate decision tree.
:param data: DataB()
:param attri_set: set()
:return:
"""
# Create a new node.
newNode = Node()
# If data set is already classified, return a leaf node.
if data.is_positive():
newNode.set_leaf(True)
return newNode
elif data.is_negative():
newNode.set_leaf(False)
return newNode
# If attribute set is empty, can't be classified.
if not attri_set:
            # Use the majority class among the remaining samples.
            decision = data.mark_most()
            newNode.set_leaf(decision)
return newNode
# Find a best decision attribute.
# If it is a continuous attribute, it should have a best mid point.
choice, midpoint = self.find_best(data, attri_set)
        if choice == -1:
            print("error: no attribute could be selected")
            return None
        print("best choice:", Attribute(choice), midpoint)
newNode.attri = Attribute(choice)
# Create a new attribute set,
# which doesn't contain the best choice just find.
new_attri_set = deepcopy(attri_set)
new_attri_set.remove(choice)
# Create branches.
for val in self.attri_list[choice]:
data_v = data.filter(choice, val, midpoint=midpoint)
if data_v.empty():
# If branch has empty data, create a leaf child.
childNode = Node()
childNode.set_leaf(data.mark_most()) # set parent's most
newNode.children.append(childNode)
else:
# Recursively generate decision child tree.
childNode = self.tree_gen(data_v, new_attri_set)
newNode.children.append(childNode)
return newNode
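The base class deliberately leaves the splitting criterion to subclasses: tree_gen only needs find_best to hand back an (attribute, midpoint) pair. As a quick illustration of that contract (a hypothetical toy subclass, not part of the exercise code), a splitter that picks an attribute at random could look like this:

import random
from decision_tree import Decisiontree

class RandomTree(Decisiontree):
    # Toy subclass, only to show the find_best interface expected by tree_gen:
    # return (attribute_index, midpoint_or_None).
    def find_best(self, data, attri_set):
        return random.choice(sorted(attri_set)), None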
info_gain.py
from decision_tree import Decisiontree
from decision_tree import DataB
from math import log
import numpy as np
import preprocess
COLOR = 0
ROOT = 1
SOUND = 2
TEXTURE = 3
NAVEL = 4
TOUCH = 5
DENSITY = 6
SUGAR = 7
discrete_set = set([COLOR, ROOT, SOUND, TEXTURE, NAVEL, TOUCH])
continuous_set = set([SUGAR, DENSITY])
class Data(DataB):
"""
Data set access interface.
"""
def filter(self, ai, val, midpoint=None):
"""
Override filter.
:param ai: int
:param val: int
:param midpoint: float
:return: Data()
"""
if midpoint is not None:
if val == 0:
filter_idx = [x for x in self.idx if self.data[x][ai] <= midpoint]
elif val == 1:
filter_idx = [x for x in self.idx if self.data[x][ai] > midpoint]
else:
filter_idx = [x for x in self.idx if self.data[x][ai] == val]
return Data(self.data, self.y, filter_idx)
def get_filter_idx(self, ai, av, is_cont=False):
"""
:param ai: int
:param av: int
:param is_cont: bool
:return: List(int)
"""
        if is_cont:
            # Use <= so the split boundary matches the filter() convention above.
            return [x for x in self.idx if self.data[x][ai] <= av]
return [x for x in self.idx if self.data[x][ai] == av]
def num_positive(self):
"""
:return: int
"""
return sum([self.y[i] for i in self.idx])
def num_positive_v(self, ai, av, is_cont=False):
"""
:param ai: int
:param av: int
:param is_cont: bool
:return: int
"""
filter_list = self.get_filter_idx(ai, av, is_cont)
return sum([self.y[i] for i in filter_list])
def num_negative(self):
"""
:return: int
"""
return len(self.idx) - sum([self.y[i] for i in self.idx])
def num_negative_v(self, ai, av, is_cont=False):
"""
:param ai: int
        :param av: int
        :param is_cont: bool
        :return: int
"""
filter_list = self.get_filter_idx(ai, av, is_cont)
return len(filter_list) - sum([self.y[i] for i in filter_list])
class DecisionTreeInfoGain(Decisiontree):
"""
Decision Tree calculate by Information Gain.
Mainly override the find_best method.
"""
def __init__(self, attri_list, midval=None):
"""
:param attri_list: Dict()
:param midval: Dict()
:return:
"""
        super(DecisionTreeInfoGain, self).__init__(attri_list)
self.midval = midval
def find_best(self, data, attri_set):
"""
Override find_best method
:param data: Data()
:param attri_set: Set(int)
        :return: (int, float or None) best attribute index and midpoint
"""
# Best choice with max info gain.
best = -1
# Calculate max info gain.
max_gain = 0
# If continuous attribute, find a best mid break point.
point = None
# Entropy of original.
s0 = data.num_positive() + data.num_negative()
e0 = self.entropy(data.num_positive(), data.num_negative())
print("Current entropy %f" %e0)
# Calculate entropy of different branches.
for ai in attri_set:
ei = 0
# Check whether ai is continuous attribute.
if ai in discrete_set:
# For discrete attributes.
for av in self.attri_list[ai]:
s1 = data.num_positive_v(ai, av)
s2 = data.num_negative_v(ai, av)
if s1 + s2 == 0:
continue
ei += (s1+s2)/float(s0) * self.entropy(s1, s2)
                info_gain = e0 - ei
                if info_gain > max_gain:
                    max_gain = info_gain
                    best = ai
                    # A discrete best attribute has no threshold, so clear any
                    # midpoint remembered from an earlier continuous candidate.
                    point = None
else:
# For continuous attributes.
sp_t = data.num_positive()
sn_t = data.num_negative()
for av in self.midval[ai]:
ei = 0
sp_l = data.num_positive_v(ai, av, True)
sn_l = data.num_negative_v(ai, av, True)
if sp_l + sn_l != 0:
ei += (sp_l+sn_l)/float(s0) * self.entropy(sp_l, sn_l)
sp_r = sp_t - sp_l
sn_r = sn_t - sn_l
if sp_r + sn_r != 0:
ei += (sp_r+sn_r)/float(s0) * self.entropy(sp_r, sn_r)
info_gain = e0 - ei
if info_gain > max_gain:
max_gain = info_gain
best = ai
# Remember to record mid point.
point = av
if point is not None:
print("Mid Point: %f" % point)
return best, point
    @staticmethod
    def entropy(s1, s2):
        """
        Binary entropy of a set with s1 positive and s2 negative samples.
        :param s1: int
        :param s2: int
        :return: float
        """
        # A pure (or empty) set has zero entropy; this also avoids log(0).
        if s1 == 0 or s2 == 0:
            return 0.0
        p1 = float(s1) / (s1 + s2)
        p2 = float(s2) / (s1 + s2)
        # Only two labels: T or F.
        return -p1 * log(p1, 2) - p2 * log(p2, 2)
if __name__ == '__main__':
    # Combine discrete and continuous attributes.
attri_set = discrete_set.union(continuous_set)
# All possible branches.
attri_list = {
COLOR: [0, 1, 2],
ROOT: [0, 1, 2],
SOUND: [0, 1, 2],
TEXTURE: [0, 1, 2],
NAVEL: [0, 1, 2],
TOUCH: [0, 1],
DENSITY: [0, 1],
SUGAR: [0, 1]
}
# Load data from text file.
data_file = "data_set3.0.txt"
data_set = np.loadtxt(data_file, dtype=np.float16, delimiter=',')
    # Precompute candidate split points (midpoints) for the continuous attributes.
midval = dict()
for ai in continuous_set:
val_list = preprocess.cont2mid(data_set, ai)
midval[ai] = val_list
# Data label.
y = [1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0]
# Data index.
    idx = list(range(len(y)))
# Create data object to manage data.
data = Data(data_set, y, idx)
# Create decision tree object
dt = DecisionTreeInfoGain(attri_list, midval=midval)
# Generate decision tree.
root = dt.tree_gen(data, attri_set)
# Print the decision tree (BFS).
print("Travel tree, breath first search")
q = [root]
while len(q) > 0:
root = q.pop(0)
if root.isLeaf:
print("Good or Bad", root.decision)
else:
print("Choice: ", root.attri)
for child in root.children:
q.append(child)
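The generated tree could in principle classify a new sample by walking from the root. Here is a minimal sketch, assuming only discrete splits (the Node class above does not record the chosen midpoint, so continuous splits cannot be replayed without extending it); the classify helper is hypothetical and not part of the original files:

def classify(node, sample):
    # Follow the branch whose index equals the sample's (discrete) attribute
    # value; children are appended in the order of attri_list[choice].
    while not node.isLeaf:
        ai = node.attri.value              # index of the attribute to test
        node = node.children[int(sample[ai])]
    return node.decision                   # True = good melon, False = bad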
preprocess.py
import numpy as np
def cont2mid(data_set, ai):
"""
Preprocess of continuous attributes.
:param data_set: Array
:param ai: int
:return: List()
"""
col = list(data_set[:, ai])
col.sort()
    t = [(col[i] + col[i + 1]) / 2 for i in range(len(col) - 1)]
return t
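For instance, with a tiny made-up array (just to show the return value):

import numpy as np
from preprocess import cont2mid

demo = np.array([[0.1, 5.0],
                 [0.3, 5.0],
                 [0.2, 5.0]])
print(cont2mid(demo, 0))   # candidate thresholds for column 0, roughly [0.15, 0.25]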
data_set3.0.txt
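# color, root, sound, texture, navel, touch, density, sugar (watermelon data set 3.0; lines starting with '#' are skipped by np.loadtxt)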
0, 0, 0, 0, 0, 0, 0.697, 0.460
1, 0, 1, 0, 0, 0, 0.774, 0.376
1, 0, 0, 0, 0, 0, 0.634, 0.264
0, 0, 1, 0, 0, 0, 0.608, 0.318
2, 0, 0, 0, 1, 0, 0.556, 0.215
0, 1, 0, 0, 1, 1, 0.403, 0.237
1, 1, 0, 1, 1, 1, 0.481, 0.149
1, 1, 0, 0, 1, 0, 0.437, 0.211
1, 1, 1, 1, 1, 0, 0.666, 0.091
0, 2, 2, 0, 2, 1, 0.243, 0.267
2, 2, 2, 2, 2, 0, 0.245, 0.057
2, 1, 0, 2, 2, 1, 0.343, 0.099
0, 1, 0, 1, 0, 0, 0.639, 0.161
2, 1, 1, 1, 0, 0, 0.657, 0.198
1, 1, 0, 0, 1, 1, 0.360, 0.370
2, 0, 0, 2, 2, 0, 0.593, 0.042
0, 0, 1, 1, 1, 0, 0.719, 0.103
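To try it out, put decision_tree.py, info_gain.py, preprocess.py and data_set3.0.txt in the same directory (numpy and aenum need to be installed) and run python info_gain.py with Python 3; it prints the current entropy and the chosen attribute at every split, followed by a breadth-first dump of the finished tree.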