# Python classification tree: an ID3 classification tree implemented in Python

import numpy as np

import pandas as pd

from load_data import inputs

from scipy.stats import mode

class ID3(object):
    """One node of an ID3 decision tree built over categorical attributes.

    A tree is grown by calling :meth:`buildTree` on a fresh node; the node
    then creates and owns its child nodes recursively.
    """

    # Information gains at or below this value are treated as "no gain".
    # Gains are differences of floating-point entropies, so an exact
    # comparison with 0.0 can be defeated by rounding noise.
    MIN_GAIN = 1e-12

    def __init__(self):
        # Child nodes, one per distinct value of the split attribute,
        # in the order returned by split_data(). Empty for a leaf.
        self.successors = []
        # Column index of the attribute with the highest information gain.
        # It is assigned on every node, including leaves.
        self.attribute = None
        # Most frequent class label among the samples that reached this node.
        self.class_value = None
        # Mapping {class label: sample count} for the samples at this node.
        self.class_distribution = None

    def buildTree(self, X, y):
        """Grow the subtree rooted at this node from samples ``X`` / labels ``y``.

        Args:
            X: 2-D array of attribute values, indexed as ``X[:, column]``.
            y: 1-D array of class labels, aligned with the rows of ``X``.

        Raises:
            ValueError: if ``y`` is empty.
        """
        y = np.asarray(y)
        if y.size == 0:
            raise ValueError("cannot build a tree from an empty label vector")

        # Start from a clean slate so that calling buildTree() twice on the
        # same node does not accumulate duplicate children.
        self.successors = []

        # Majority vote. np.unique returns the labels sorted, and argmax picks
        # the first maximum, so ties resolve to the smallest label. This also
        # works for non-numeric labels and does not depend on the return
        # shape of scipy.stats.mode, which differs between SciPy versions.
        labels, label_counts = np.unique(y, return_counts=True)
        self.class_value = labels[np.argmax(label_counts)]
        self.class_distribution = pd.Series(y).value_counts().to_dict()

        self.attribute, info_gain = compute_best_attribute(X, y)
        if info_gain <= self.MIN_GAIN:
            return  # No attribute improves purity: stop branching (leaf).

        for sub_X, sub_y in split_data(X, y, self.attribute):
            sub_tree = ID3()
            sub_tree.buildTree(sub_X, sub_y)
            self.successors.append(sub_tree)

def split_data(X, y, attribute):
    """Partition ``(X, y)`` by the distinct values found in one column of ``X``.

    Args:
        X: 2-D array of attribute values.
        y: 1-D array of labels aligned with the rows of ``X``.
        attribute: Column index of ``X`` to split on.

    Returns:
        A list of ``(sub_X, sub_y)`` tuples, one per distinct value of the
        column, in the sorted order produced by ``np.unique``.
    """
    column = X[:, attribute]  # Slice once instead of once per distinct value.
    result = []
    for value in np.unique(column):
        mask = column == value
        result.append((X[mask], y[mask]))
    return result

def entropy(vec=None, freq=None):
    """Return the Shannon entropy (base 2) of a sample or of a distribution.

    Exactly one of the two inputs is used; ``vec`` takes precedence when
    both are given.

    Args:
        vec: 1-D array-like of raw observations. Relative frequencies of its
            distinct values are computed from it.
        freq: 1-D array-like of relative frequencies (probabilities). Used
            only when ``vec`` is None.

    Returns:
        The entropy in bits.

    Raises:
        ValueError: if neither ``vec`` nor ``freq`` is provided.
    """
    if vec is not None:
        vec = np.asarray(vec)
        _, counts = np.unique(vec, return_counts=True)
        freq = counts / vec.size
    elif freq is None:
        raise ValueError("entropy() requires either `vec` or `freq`")
    else:
        freq = np.asarray(freq, dtype=float)

    # Zero-probability outcomes contribute nothing to the entropy; dropping
    # them avoids 0 * log2(0) evaluating to NaN.
    nonzero = freq[freq > 0]
    return -nonzero.dot(np.log2(nonzero))

def compute_info_gain(vec, labels):
    """Return the information gain obtained by splitting ``labels`` on ``vec``.

    The gain is the entropy of ``labels`` minus the conditional entropy of
    ``labels`` given ``vec``, i.e. the frequency-weighted average entropy of
    the label subsets selected by each distinct value of ``vec``.

    Args:
        vec: 1-D array holding one attribute's value for every sample.
        labels: 1-D array of class labels aligned with ``vec``.

    Returns:
        The information gain in bits.
    """
    values, counts = np.unique(vec, return_counts=True)
    value_freqs = counts / vec.size

    # Weighted entropy of the label subsets, one subset per attribute value.
    conditional_entropy = 0
    for value, freq in zip(values, value_freqs):
        conditional_entropy += freq * entropy(labels[vec == value])

    return entropy(labels) - conditional_entropy

def compute_best_attribute(X, y):
    """Find the column of ``X`` with the maximum information gain.

    Args:
        X: 2-D array of attribute values; every column is a candidate.
        y: 1-D array of class labels aligned with the rows of ``X``.

    Returns:
        A ``(best_attr, best_gain)`` tuple: the column index and its gain.
        Ties go to the lowest column index because only a strictly greater
        gain replaces the current best. When ``X`` has no columns the result
        is ``(0, float('-inf'))``.
    """
    best_gain = float('-inf')
    best_attr = 0
    for idx in range(X.shape[1]):
        gain = compute_info_gain(X[:, idx], y)
        if gain > best_gain:
            best_gain = gain
            best_attr = idx
    return best_attr, best_gain

def print_ID3(tree, attribute_names=None):
    """Print the tree rooted at ``tree`` in pre-order, one node per line.

    Internal nodes print the split attribute and the node's class
    distribution; leaves print only their class distribution.

    Args:
        tree: Node exposing ``successors``, ``attribute`` and
            ``class_distribution``.
        attribute_names: Optional sequence mapping a column index to a
            display name. When omitted, a module-level ``attribute_names``
            is used if one exists; otherwise the raw column index is printed.
    """
    if attribute_names is None:
        # Backward compatibility: earlier callers relied on a module-level
        # global that is only defined when the file runs as a script.
        attribute_names = globals().get('attribute_names')

    if tree.successors:
        if attribute_names is not None:
            label = attribute_names[tree.attribute]
        else:
            label = tree.attribute
        print("attribute: {0}, y = {1}".format(label, tree.class_distribution))
        for child in tree.successors:
            print_ID3(child, attribute_names)
    else:
        print("leaf: y = {0}".format(tree.class_distribution))

if __name__ == '__main__':
    # `inputs` comes from load_data and is unpacked as (record, label) pairs.
    # NOTE(review): presumably each record is a dict keyed by the attribute
    # names below, in that order - confirm against load_data.
    X, y = zip(*inputs)
    X = pd.DataFrame.from_records(X).values
    y = np.array(y)

    # Kept at module level on purpose: print_ID3 looks this name up as a
    # global to turn column indices into readable attribute names.
    attribute_names = ['level', 'lang', 'tweets', 'phd']

    # The earlier stand-alone compute_best_attribute(X, y) call was dropped:
    # its result was discarded and buildTree() performs the same computation.
    tree = ID3()
    tree.buildTree(X, y)
    print_ID3(tree)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值