# coding=utf-8
import math
'''
CART决策树模型,假设有三个条件
年龄,有三个选项 1 表示老年人 2 表示中年人 3 表示青年人
工作,有两个选项 1 表示有工作 2表示 没有工作
房子,有两个选项 1 表示有房子 2表示 没有房子
信贷情况 1表示一般 2表示好 3表示非常好
输出,有两个选项 1 表示给贷款 2表示 不予贷款
要求: 依次获得每个选项的基尼指数
'''
# Size threshold read by proc_node: a node is left unsplit when either of its
# two partitions holds this many rows or fewer.
min_row_count = 2

# Display labels for the encoded values of each feature, listed in the same
# order as g_columns (age, job, house, credit).
g_desc = [
    {1: "老年人", 2: "中年人", 3: "青年人"},
    {1: "有工作", 2: "没工作"},
    {1: "有房子", 2: "没有房子"},
    {1: "信贷一般", 2: "信贷好", 3: "信贷非常好"},
]

# Feature names indexed by column position; bianli_node prints these.
g_columns = ["年龄", "工作", "房子", "信贷"]

# Sample rows laid out as [age, job, house, credit, label].
# The last column is the label: 1 = loan granted, 2 = loan refused.
sample_input = [
    [3, 2, 2, 1, 2],
    [3, 2, 2, 2, 2],
    [3, 1, 2, 2, 1],
    [3, 1, 1, 1, 1],
    [3, 2, 2, 1, 2],
    [2, 2, 2, 1, 2],
    [2, 2, 2, 2, 2],
    [2, 1, 1, 2, 1],
    [2, 2, 1, 3, 1],
    [2, 2, 1, 3, 1],
    [1, 2, 1, 3, 1],
    [1, 2, 1, 2, 1],
    [1, 1, 2, 2, 1],
    [1, 1, 2, 3, 1],
    [1, 2, 2, 1, 2],
]
class Node:
    """A single binary split in the decision tree.

    Attributes:
        index: column position of the feature this split tests.
        value: feature value the split compares rows against.
        jini: Gini score recorded for this split.
        dataset_1: first partition of rows (as filled by
            get_proper_tezheng: rows whose feature equals ``value``).
        dataset_2: second partition of rows (all remaining rows).
        left: child node for ``dataset_1``; ``None`` until set.
        right: child node for ``dataset_2``; ``None`` until set.
    """

    def __init__(self, index, value, jini, dataset_1, dataset_2):
        self.index, self.value, self.jini = index, value, jini
        self.dataset_1, self.dataset_2 = dataset_1, dataset_2
        # Children are attached later through set_left / set_right.
        self.left = self.right = None

    def set_ds1(self, ds1):
        """Replace the first partition."""
        self.dataset_1 = ds1

    def set_ds2(self, ds2):
        """Replace the second partition."""
        self.dataset_2 = ds2

    def set_left(self, left):
        """Attach ``left`` as the left child."""
        self.left = left

    def set_right(self, right):
        """Attach ``right`` as the right child."""
        self.right = right
# 计算 right_index列的基尼值
def get_proper_tezheng(sample_input, right_index):
    """Find the best one-vs-rest split on a single feature column.

    Every distinct value found in column ``right_index`` is tried as a split
    point: rows equal to the value form the first partition and all other
    rows form the second. A candidate's score is the size-weighted sum of
    ``get_gini`` over its two partitions. The lowest score wins; on a tie the
    value seen first in ``sample_input`` is kept.

    Args:
        sample_input: list of rows (``get_gini`` reads each row's last
            element as its label).
        right_index: position of the feature column to evaluate.

    Returns:
        A ``Node`` carrying ``right_index``, the winning value, its score and
        the two partitions. When ``sample_input`` is empty no candidate is
        evaluated and the node keeps the placeholders (value ``0``,
        score ``9``, two empty partitions).
    """
    total_rows = len(sample_input)

    def split_on(value):
        # One pass over the rows: equal-to-value rows first, the rest second.
        equal_rows, other_rows = [], []
        for row in sample_input:
            (equal_rows if row[right_index] == value else other_rows).append(row)
        return equal_rows, other_rows

    # dict.fromkeys keeps first-seen order, so candidates are scored in the
    # order their values appear in the data.
    candidates = dict.fromkeys(row[right_index] for row in sample_input)

    best_value = 0
    best_score = 9  # sentinel above any real score (weighted Gini is at most 1)
    for candidate in candidates:
        equal_rows, other_rows = split_on(candidate)
        score = (len(equal_rows) / total_rows) * get_gini(equal_rows) + (
            len(other_rows) / total_rows) * get_gini(other_rows)
        # Strict comparison: an equal score never displaces an earlier winner.
        if score < best_score:
            best_score = score
            best_value = candidate

    # Rebuild the partitions for the winning split point and package them.
    best_equal, best_other = split_on(best_value)
    return Node(right_index, best_value, best_score, best_equal, best_other)
def get_finall_node(sample_input):
    """Pick the best split across all feature columns.

    Every column except the last one (the label column) is scored with
    ``get_proper_tezheng``; the split with the lowest ``jini`` is returned.

    Args:
        sample_input: non-empty list of rows; the first row's length fixes
            the number of columns.

    Returns:
        The ``Node`` with the lowest ``jini``. Ties go to the lower column
        index. If no column scores strictly below 1, the baseline
        ``Node(0, 0, 1, [], [])`` is returned.

    Raises:
        IndexError: if ``sample_input`` is empty.
    """
    feature_count = len(sample_input[0]) - 1  # last column is the label

    # Baseline with score 1: a real split replaces it only when strictly lower.
    contenders = [Node(0, 0, 1, [], [])]
    contenders.extend(
        get_proper_tezheng(sample_input, column) for column in range(feature_count)
    )

    # min() keeps the first minimal item, which is exactly what a scan with a
    # strict "<" update would select.
    return min(contenders, key=lambda node: node.jini)
# 获得一个矩阵的基尼值
def get_gini(dataset):
    """Return the Gini impurity of the labels in ``dataset``.

    The label of a row is its last element. The result is
    ``1 - sum(p ** 2)`` over the share ``p`` of each distinct label.

    Args:
        dataset: list of rows.

    Returns:
        The impurity as a number. An empty ``dataset`` has no labels to sum
        over, so the result is ``1``.
    """
    total_rows = len(dataset)

    # Tally how many rows carry each label.
    label_counts = {}
    for row in dataset:
        label = row[-1]
        label_counts[label] = label_counts.get(label, 0) + 1

    squared_shares = 0
    for count in label_counts.values():
        share = count / total_rows
        squared_shares += share * share
    return 1 - squared_shares
def proc_node(current_node):
    """Grow the subtree below ``current_node`` and return the node.

    A node is left as a leaf when either of its partitions holds
    ``min_row_count`` rows or fewer. Otherwise the best split of
    ``dataset_1`` becomes the left child and the best split of ``dataset_2``
    becomes the right child, and each child is expanded by the same rule.

    Earlier behaviour attached only one level of children and returned
    ``None`` on the splitting path; the node is now returned on every path
    and the children are expanded recursively.

    Args:
        current_node: a ``Node`` whose ``dataset_1`` / ``dataset_2`` are
            already populated.

    Returns:
        ``current_node`` itself, with ``left`` / ``right`` set when it was
        split.
    """
    partition_too_small = (
        len(current_node.dataset_1) <= min_row_count
        or len(current_node.dataset_2) <= min_row_count
    )
    if partition_too_small:
        return current_node

    # Both partitions are larger than min_row_count here. Assuming
    # min_row_count >= 0, both are non-empty, so each is strictly smaller
    # than their union and the recursion terminates.
    current_node.set_left(proc_node(get_finall_node(current_node.dataset_1)))
    current_node.set_right(proc_node(get_finall_node(current_node.dataset_2)))
    return current_node
def bianli_node(root):
    """Print the tree under ``root`` in pre-order, one node per line.

    Each line shows the node's ``jini`` followed by the name of its split
    column, looked up in ``g_columns`` by ``root.index``.

    Args:
        root: the ``Node`` to start from.
    """
    print(" ".join((str(root.jini), g_columns[root.index])))
    # Visit the left subtree before the right one; unset children are skipped.
    for child in (root.left, root.right):
        if child:
            bianli_node(child)
# Script entry: choose the root split for the sample data, let proc_node
# attach child splits to it, then print the resulting tree.
root = get_finall_node(sample_input)
proc_node(root)
bianli_node(root)