import numpy as np
import pandas as pd
from sklearn.utils.multiclass import type_of_target
from decision_tree import treePlottter
class Node:
    """One node of the decision tree: either an internal split node or a leaf."""

    def __init__(self):
        self.feature_name = None     # name of the feature this node splits on
        self.feature_index = None    # position of that feature in the original column list
        self.subtree = {}            # maps a branch key (feature value or threshold label) to a child Node
        self.impurity = None         # information gain achieved by this node's split
        self.is_continuous = False   # True when the split feature is continuous
        self.split_value = None      # threshold used when the split feature is continuous
        self.is_leaf = False         # True when this node is a leaf
        self.leaf_class = None       # class predicted by this node when it is a leaf
        self.leaf_num = 0            # number of leaves in the subtree rooted here
        self.high = -1               # height of the subtree rooted here (0 for a leaf)
def entroy(y):
    """Return the base-2 Shannon entropy of the labels in ``y``.

    :param y: class labels of the samples (pandas Series or other 1-D array-like)
    :return: ``-sum(p_k * log2(p_k))`` where ``p_k`` is the share of each distinct label
    """
    labels = pd.Series(y)
    # Series.value_counts replaces the deprecated top-level pd.value_counts.
    p = labels.value_counts() / labels.shape[0]  # proportion of each class
    ent = np.sum(-p * np.log2(p))
    return ent
def info_gain(feature, y, entD, is_continuous=False):
    '''
    Compute the information gain of splitting the samples on one feature.
    ------
    :param feature: values of the feature for every current sample
    :param y: labels aligned with ``feature``
    :param entD: entropy of ``y`` before the split
    :param is_continuous: treat the feature as continuous (binary split on a threshold)
    :return: a list; ``[gain]`` for a discrete feature, ``[gain, split_point]`` for a
             continuous one (``split_point`` is the candidate midpoint with the lowest
             weighted entropy, or None when there is no candidate)
    '''
    total = y.shape[0]
    distinct = pd.unique(feature)

    if not is_continuous:
        # Discrete feature: one branch per distinct value.
        conditional_ent = 0
        for v in distinct:
            subset = y[feature == v]  # samples whose feature equals v
            conditional_ent += subset.shape[0] / total * entroy(subset)
        return [entD - conditional_ent]

    # Continuous feature: candidate thresholds are midpoints of adjacent sorted values.
    distinct.sort()
    candidates = [(lo + hi) / 2 for lo, hi in zip(distinct[:-1], distinct[1:])]
    best_ent, best_point = float('inf'), None
    for point in candidates:
        left = y[feature <= point]
        right = y[feature > point]
        weighted = left.shape[0] / total * entroy(left) + right.shape[0] / total * entroy(right)
        if weighted < best_ent:
            best_ent, best_point = weighted, point
    return [entD - best_ent, best_point]
def choose_best_feature_infogain(X, y):
    '''
    Pick the feature with the largest information gain.

    The length of the returned gain list tells the caller the feature kind:
    length 1 means discrete, length 2 means continuous (second item is the split point).
    :param X: DataFrame holding the currently available features
    :param y: labels
    :return: ``(feature_name, gain_list)``; the name is None when no feature's gain
             exceeds negative infinity
    '''
    base_entropy = entroy(y)
    winner_name = None
    winner_gain = [float('-inf')]
    for name in X.columns:
        column = X[name]
        continuous = type_of_target(column) == 'continuous'
        candidate = info_gain(column, y, base_entropy, continuous)
        # Strict comparison keeps the earliest feature on ties.
        if candidate[0] > winner_gain[0]:
            winner_name, winner_gain = name, candidate
    return winner_name, winner_gain
def _make_leaf(label):
    """Build a leaf node (height 0, one leaf) that predicts ``label``."""
    leaf = Node()
    leaf.is_leaf = True
    leaf.leaf_class = label
    leaf.high = 0
    leaf.leaf_num += 1
    return leaf


def generate(X, y, columns):
    """Recursively build a decision tree using information gain.

    :param X: DataFrame with the features still available for the current samples
    :param y: Series of labels aligned with ``X``
    :param columns: list of the original column names, used to look up ``feature_index``
    :return: the root ``Node`` of the (sub)tree built from ``X`` and ``y``
    """
    if y.nunique() == 1:  # every sample has the same class
        return _make_leaf(y.values[0])
    if X.empty:  # no features left: predict the most frequent class
        return _make_leaf(y.value_counts().index[0])

    best_feature_name, best_impurity = choose_best_feature_infogain(X, y)
    if best_feature_name is None:
        # No feature offers a usable split (e.g. only continuous features remain and
        # each has a single distinct value, so every gain is -inf). Fall back to the
        # majority class instead of failing on columns.index(None).
        return _make_leaf(y.value_counts().index[0])

    node = Node()
    node.feature_name = best_feature_name
    node.impurity = best_impurity[0]
    node.feature_index = columns.index(best_feature_name)
    feature_values = X.loc[:, best_feature_name]

    if len(best_impurity) == 1:  # discrete feature
        node.is_continuous = False
        # A discrete feature is consumed by the split, so children do not see it again.
        sub_X = X.drop(best_feature_name, axis=1)
        max_high = -1
        for value in pd.unique(feature_values):
            mask = feature_values == value
            child = generate(sub_X[mask], y[mask], columns)
            node.subtree[value] = child
            if child.high > max_high:  # track the tallest child
                max_high = child.high
            node.leaf_num += child.leaf_num
        node.high = max_high + 1
    elif len(best_impurity) == 2:  # continuous feature
        node.is_continuous = True
        node.split_value = best_impurity[1]
        up_part = '>= {:.3f}'.format(node.split_value)
        down_part = '< {:.3f}'.format(node.split_value)
        upper_mask = feature_values >= node.split_value
        lower_mask = feature_values < node.split_value
        # The continuous feature stays in X, so it can be split again deeper in the tree.
        node.subtree[up_part] = generate(X[upper_mask], y[upper_mask], columns)
        node.subtree[down_part] = generate(X[lower_mask], y[lower_mask], columns)
        node.leaf_num += (node.subtree[up_part].leaf_num + node.subtree[down_part].leaf_num)
        node.high = max(node.subtree[up_part].high, node.subtree[down_part].high) + 1
    return node
if __name__ == "__main__":
    # The first CSV column is used as the row index, so it is not part of the data columns.
    data = pd.read_csv("西瓜3.0.txt", index_col=0)
    # The first 8 columns are the features (DataFrame); the column at position 8 is the label (Series).
    x, y = data.iloc[:, :8], data.iloc[:, 8]
    columns_name = list(x.columns)  # feature column names of the original data
    node = generate(x, y, columns_name)
    treePlottter.create_plot(node)
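# Another piece: the tree-plotting code. Everything below belongs to a separate .py file (presumably the treePlottter module imported above).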
from matplotlib import pyplot as plt
# Use the SimHei font for sans-serif text; presumably so that Chinese feature names and labels render correctly (TODO confirm).
plt.rcParams['font.sans-serif'] = ['SimHei']
# Box styles passed as ``bbox`` when annotating decision nodes and leaf nodes.
decision_node = {'boxstyle': 'round,pad=0.3', 'fc': '#FAEBD7'}
leaf_node = {'boxstyle': 'round,pad=0.3', 'fc': '#F4A460'}
# Arrow style used for the connector between a node and its parent.
arrow_args = {'arrowstyle': "<-"}
# Module-level layout state; create_plot assigns real values before drawing.
y_off = x_off = None
total_num_leaf = total_high = None
def plot_node(node_text, center_pt, parent_pt, node_type, ax_):
    """Annotate ``ax_`` with a boxed node at ``center_pt`` connected to ``parent_pt``.

    :param node_text: text shown inside the box
    :param center_pt: (x, y) of the box, in axes-fraction coordinates
    :param parent_pt: (x, y) of the parent node, in axes-fraction coordinates
    :param node_type: bbox style dict for the box
    :param ax_: axes object to draw on
    """
    # The connector is anchored 0.02 below the parent's y position.
    anchor = [parent_pt[0], parent_pt[1] - 0.02]
    ax_.annotate(
        node_text,
        xy=anchor, xycoords='axes fraction',
        xytext=center_pt, textcoords='axes fraction',
        va="center", ha="center", size=15,
        bbox=node_type, arrowprops=arrow_args,
    )
def plot_mid_text(mid_text, center_pt, parent_pt, ax_):
    """Write ``mid_text`` halfway between ``center_pt`` and ``parent_pt`` on ``ax_``."""
    cx, cy = center_pt[0], center_pt[1]
    px, py = parent_pt[0], parent_pt[1]
    x_mid = (px - cx) / 2 + cx
    y_mid = (py - cy) / 2 + cy
    ax_.text(x_mid, y_mid, mid_text, fontdict={'size': 10})
def plot_tree(my_tree, parent_pt, node_text, ax_):
    """Recursively draw ``my_tree`` on ``ax_`` below ``parent_pt``.

    :param my_tree: node to draw (reads leaf_num, feature_name, leaf_class, subtree)
    :param parent_pt: (x, y) of the parent node
    :param node_text: label written on the edge coming from the parent
    :param ax_: axes object to draw on

    Uses and updates the module-level cursors ``x_off`` / ``y_off``.
    """
    global y_off, x_off
    global total_num_leaf, total_high

    leaves_here = my_tree.leaf_num
    # Centre this node horizontally over the leaves it spans.
    center_pt = (x_off + (1 + leaves_here) / (2 * total_num_leaf), y_off)
    plot_mid_text(node_text, center_pt, parent_pt, ax_)

    if total_high == 0:
        # Whole tree has height 0, i.e. it is a single leaf (can occur with pre-pruning).
        plot_node(my_tree.leaf_class, center_pt, parent_pt, leaf_node, ax_)
        return

    plot_node(my_tree.feature_name, center_pt, parent_pt, decision_node, ax_)

    level_step = 1 / total_high
    y_off -= level_step  # descend one level for the children
    for edge_label, child in my_tree.subtree.items():
        if not child.is_leaf:
            plot_tree(child, center_pt, str(edge_label), ax_)
            continue
        x_off += 1 / total_num_leaf  # advance to the next leaf slot
        leaf_pt = (x_off, y_off)
        plot_node(str(child.leaf_class), leaf_pt, center_pt, leaf_node, ax_)
        plot_mid_text(str(edge_label), leaf_pt, center_pt, ax_)
    y_off += level_step  # restore the level for the caller
def create_plot(tree_):
    """Draw the decision tree rooted at ``tree_`` in a new matplotlib figure and show it.

    :param tree_: root node; its ``leaf_num`` and ``high`` set the layout scale
    """
    global y_off, x_off, total_num_leaf, total_high

    total_num_leaf = tree_.leaf_num
    total_high = tree_.high
    y_off = 1
    # Start half a leaf slot left of the axes so the first leaf lands mid-slot.
    x_off = -0.5 / total_num_leaf

    ax_ = plt.subplots()[1]
    # Hide the tick marks and all four axis spines.
    ax_.set_xticks([])
    ax_.set_yticks([])
    for side in ('right', 'top', 'bottom', 'left'):
        ax_.spines[side].set_color('none')

    plot_tree(tree_, (0.5, 1), '', ax_)
    plt.show()