最近在看决策树,在B站上看到了一个前辈的讲课视频
讲的非常详细,于是自己手动实现了一下基于ID3的决策树
说来惭愧,我是新手,所以并没有导包,纯原始python写的。也并没有画出最后决策树的构建图。
只是让我对这个决策树更加了解一些,后续学到引入外部包,再说。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import math
from collections import Counter

import numpy as np
# Training set used to build the decision tree (the classic "play tennis" data).
# Row layout: [Outlook, Temperature, Humidity, Wind, decision label ('Yes'/'No')].
data = [
    ['Sunny', 'Hot', 'High', 'Weak', 'No'],
    ['Sunny', 'Hot', 'High', 'Strong', 'No'],
    ['Overcast', 'Hot', 'High', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'High', 'Weak', 'Yes'],
    ['Rain', 'Cool', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Cool', 'Normal', 'Strong', 'No'],
    ['Overcast', 'Cool', 'Normal', 'Strong', 'Yes'],
    ['Sunny', 'Mild', 'High', 'Weak', 'No'],
    ['Sunny', 'Cool', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'Normal', 'Weak', 'Yes'],
    ['Sunny', 'Mild', 'Normal', 'Strong', 'Yes'],
    ['Overcast', 'Mild', 'High', 'Strong', 'Yes'],
    ['Overcast', 'Hot', 'Normal', 'Weak', 'Yes'],
    ['Rain', 'Mild', 'High', 'Strong', 'No']
]
# Condition attribute names, in the column order used by `data` rows.
columns = ['Outlook', 'Temperature', 'Humidity', 'Wind']
# Maps each condition attribute name to its column index in a `data` row.
columns_index = {
    'Outlook': 0,
    'Temperature': 1,
    'Humidity': 2,
    'Wind': 3,
}
# Step 1: entropy of the decision attribute, restricted by `path`.
def calculate_entropy(path):
    """Filter the global training ``data`` by the constraints in ``path`` and
    compute the entropy of the decision column over the surviving rows.

    path: dict mapping column name -> required attribute value; an empty dict
          keeps every row (used for the root of the tree).
    Returns (entropy, filtered_rows); entropy is 0 when no row matches.
    """
    # Keep only the rows that satisfy every attribute=value constraint.
    filtered_data = [
        line for line in data
        if all(line[columns_index[column]] == value
               for column, value in path.items())
    ]
    decision_entropy = 0.0
    if filtered_data:
        # Frequency of each decision label ('Yes'/'No') among matching rows.
        label_counts = Counter(line[-1] for line in filtered_data)
        total = len(filtered_data)
        for count in label_counts.values():
            probability = count / total
            decision_entropy -= probability * math.log(probability, 2)
    return decision_entropy, filtered_data
# Step 2: conditional entropy of each candidate condition attribute.
# The four condition attributes are Outlook, Temperature, Humidity and Wind;
# compute each attribute's information gain and choose the best splitter.
# E.g. Outlook splits the data into three groups:
# Sunny (D1), Overcast (D2), Rain (D3).
def child_node(parent_score, nodes, node_data, path):
    """Pick the column with the highest information gain for the next split.

    parent_score: entropy of the decision column before splitting.
    nodes: candidate column names to evaluate (the global `columns` list).
    node_data: rows that reached this point of the tree.
    path: dict of column -> value constraints already fixed above this node.
    Returns (chosen column name, {attribute value -> per-label count dict}).
    """
    # Group node_data by every candidate column:
    # node_dict[column][value] = {label: count, ..., 'count': rows with value}.
    node_dict = {}
    for line in node_data:
        for node in nodes:
            attributes = node_dict.get(node)
            if attributes is None:
                attributes = {}
            attribute = attributes.get(line[columns_index[node]])
            if attribute is None:
                attribute = {}
            num = attribute.get(line[-1])
            if num is None:
                num = 1
            else:
                num += 1
            attribute[line[-1]] = num
            # 'count' tracks how many rows carry this attribute value overall.
            attribute['count'] = 1 if attribute.get('count') is None else attribute.get('count') + 1
            attributes[line[columns_index[node]]] = attribute
            node_dict[node] = attributes
    # Fallback choice: first grouped column (used when no split improves gain).
    root = next(iter(node_dict))
    increment = 0
    for node in node_dict:
        node_score = []
        weight = []
        for attribute in node_dict[node]:
            # Entropy of the decision column conditioned on this value.
            current_path = path.copy()
            current_path[node] = attribute
            decision_entropy, filtered_data = calculate_entropy(current_path)
            node_score.append(decision_entropy)
            # Weight = fraction of rows carrying this attribute value.
            weight.append((node_dict[node][attribute]['count'] / len(node_data)))
        # Weighted conditional entropy of splitting on `node`.
        node_score = sum(np.multiply(node_score, weight))
        # Keep the column with the largest information gain so far.
        if parent_score - node_score > increment:
            increment = parent_score - node_score
            root = node
    return root, node_dict[root]
def find_attribute(root, attributes, path, tree_node):
    """Recursively grow the decision tree below the chosen column `root`.

    root: column name selected for this level.
    attributes: dict of value -> per-label counts for that column
                (second return value of `child_node`).
    path: dict of column -> value constraints accumulated so far; mutated
          while descending and restored before returning.
    tree_node: Node for the chosen column; one child Node is appended per
               attribute value.
    """
    for attribute in attributes:
        path[root] = attribute
        # Entropy of the decision column restricted to this branch.
        entropy, node_data = calculate_entropy(path)
        attribute_node = Node(attribute, tree_node, [])
        tree_node.next.append(attribute_node)
        if entropy == 0.0:
            # Pure branch: attach a leaf carrying the decision label.
            # Bug fix: the parent pointer used to be the *builtin* `next`
            # instead of the branch node.
            attribute_node.next.append(Node(node_data[-1][-1], attribute_node, None))
        elif len(node_data) > 0:
            # Impure branch: choose the next best column and recurse.
            node, sub_attributes = child_node(entropy, columns, node_data, path)
            temp_node = Node(node, attribute_node, [])  # bug fix: parent was builtin `next`
            attribute_node.next.append(temp_node)
            find_attribute(node, sub_attributes, path, temp_node)
    # Undo this level's constraint so sibling branches see a clean path.
    path.pop(root)
class Node:
    """One vertex of the decision tree.

    name: a column name (split node), an attribute value (branch node),
          or a decision label (leaf).
    before: reference to the parent node (None for the tree head).
    next: list of child nodes; None marks a leaf.
    """

    def __init__(self, name, before, next):
        self.name = name
        self.before = before
        self.next = next
# Once the tree is built, walk it to classify one sample.
def predict(root, line):
    """Descend the tree and print the decision for `line`.

    root: current Node (column node on entry, then value/leaf nodes below).
    line: dict mapping column name -> attribute value for the sample.
    Prints 'result <label>' when a leaf is reached; silent when the sample's
    value has no matching branch.
    """
    if root.next is None:
        # Leaf node: report the decision and stop.
        # Bug fix: without this return the leaf's label was looked up in
        # `line` as if it were a column name (harmless only by accident).
        print('result', root.name)
        return
    column = root.name
    attribute = line.get(column)
    if attribute is not None:
        for branch in root.next:  # renamed from `next` (shadowed the builtin)
            if branch.name == attribute:
                for child in branch.next:
                    predict(child, line)
if __name__ == '__main__':
    # Build the tree: pick the best root column over the full training set.
    # A path like {'Outlook': 'Sunny', 'Temperature': 'Hot'} constrains rows;
    # the empty dict keeps everything.
    constraints = {}
    root_entropy, matching_rows = calculate_entropy(constraints)
    best_column, best_attributes = child_node(
        root_entropy, columns, matching_rows, constraints)
    head = Node(best_column, None, [])
    # Recursively expand every branch below the root column.
    find_attribute(best_column, best_attributes, {}, head)
    # Classify a couple of unseen samples.
    test_data = [
        {
            'Temperature': 'Hot',
            'Humidity': 'High',
            'Wind': 'Weak',
            'Outlook': 'Sunny'
        },
        {
            'Outlook': 'Overcast',
            'Temperature': 'Hot',
            'Humidity': 'High',
            'Wind': 'Weak'
        },
    ]
    for sample in test_data:
        print(sample)
        predict(head, sample)