# 决策树-信息增益的计算 (Decision tree: computing information gain)

import numpy as np
import pandas as pd

from collections import Counter
import math
from math import log

# Entropy (熵) of the class label.
# Worked example: a 1:2 label split gives
#   H = -(1/3)*log2(1/3) - (2/3)*log2(2/3) ≈ 0.918


def calc_ent(datasets):
    """Return the empirical entropy H(D) of the labels in ``datasets``, in bits.

    Args:
        datasets: a sequence of rows (e.g. a list of lists or a 2-D
            ``np.ndarray``); the last element of each row is the class label.

    Returns:
        ``-sum(p_k * log2(p_k))`` over the relative label frequencies ``p_k``.
        Returns 0 for an empty ``datasets``.
    """
    data_length = len(datasets)
    # Count how often each class label (last column) occurs.
    label_count = Counter(row[-1] for row in datasets)
    # Written as p * log2(1/p) so every term is >= 0; this avoids returning
    # -0.0 (printed as "-0.000") when all rows share a single label.
    ent = sum((n / data_length) * log(data_length / n, 2)
              for n in label_count.values())
    return ent

# Empirical conditional entropy (经验条件熵) H(D|A).


def cond_ent(datasets, axis=0):
    """Return the empirical conditional entropy H(D|A) of the labels given one feature.

    The rows are partitioned by their value in column ``axis``; the result is
    the size-weighted average of each partition's label entropy:
    ``sum(|D_i| / |D| * H(D_i))``.

    Args:
        datasets: a sequence of rows; the last element of each row is the
            class label (as required by ``calc_ent``).
        axis: index of the feature column to condition on.

    Returns:
        H(D|A) in bits; 0 for an empty ``datasets``.
    """
    data_length = len(datasets)
    # Group whole rows by their value in the chosen feature column.
    feature_sets = {}
    for row in datasets:
        feature_sets.setdefault(row[axis], []).append(row)
    # Weighted sum of the label entropy of each group.
    result = sum((len(group) / data_length) * calc_ent(group)
                 for group in feature_sets.values())
    return result

# Information gain (信息增益).


def info_gain(ent, cond_ent):
    """Compute the information gain g(D, A) = H(D) - H(D|A).

    Args:
        ent: empirical entropy H(D) of the whole dataset.
        cond_ent: conditional entropy H(D|A) for some feature A.

    Returns:
        The difference ``ent - cond_ent``.
    """
    gain = ent - cond_ent
    return gain


def info_gain_train(datasets, feature_names=None):
    """Pick the feature with the largest information gain as the root split.

    Prints the information gain of every feature column, then returns a
    message naming the best one.

    Args:
        datasets: a sequence of rows (e.g. a 2-D ``np.ndarray``); every column
            except the last is a feature, the last column is the class label.
        feature_names: names indexed by column, used in the printed and
            returned messages. Defaults to the module-level ``labels`` list.

    Returns:
        A message (in Chinese) naming the feature with maximal information
        gain. Ties go to the lowest column index.

    Raises:
        IndexError: if ``datasets`` is an empty list/array (no first row).
        ValueError: if the rows contain no feature columns.
    """
    if feature_names is None:
        feature_names = labels
    # Number of feature columns: every column except the trailing label.
    count = len(datasets[0]) - 1
    if count < 1:
        raise ValueError('datasets must contain at least one feature column')
    ent = calc_ent(datasets)
    gains = []
    for c in range(count):
        c_info_gain = info_gain(ent, cond_ent(datasets, axis=c))
        gains.append((c, c_info_gain))
        print('特征({}) - info_gain - {:.3f}'.format(feature_names[c], c_info_gain))
    # max() keeps the first maximum, so ties resolve to the earliest feature.
    best_index, _ = max(gains, key=lambda x: x[1])
    return '特征({})的信息增益最大,选择为根节点特征'.format(feature_names[best_index])

# Alternative, larger example dataset (with an extra 温度/temperature
# feature), kept for experimentation:
# labels = ["天气", "温度", "湿度", "刮风", '类别']
# datasets = pd.DataFrame([
#     ["晴", "高", "中", "否", '否'],
#     ["晴", "高", "中", "是", '否'],
#     ["阴天", "高", "高", "否", '是'],
#     ["雨", "高", "高", "否", '是'],
#     ["雨", "低", "高", "否", '否'],
#     ["晴", "中", "中", "是", '是'],
#     ["阴天", "中", "高", "是", '否'],
# ])


# Example dataset. Features: 天气 (weather), 湿度 (humidity), 刮风 (windy);
# the last column 类别 is the class label.
# ``labels`` must stay at module level: info_gain_train() reads it as a
# global to name the features.
labels = ["天气", "湿度", "刮风", '类别']
datasets = pd.DataFrame([
    ["晴", "中", "否", '否'],
    ["晴", "中", "是", '否'],
    ["阴天", "高", "否", '是'],
    ["雨", "高", "否", '是']
])


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    print(datasets)
    # Convert to an ndarray: the helpers index rows positionally
    # (datasets[0], row[-1]), which on a DataFrame would select columns.
    print(info_gain_train(np.array(datasets)))
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值