# 计算信息增益 — compute information gain (ID3 decision-tree example)

# -*- coding: UTF-8 -*-
from math import log
from collections import Counter
import csv
import numpy as np
 

def createDataSet():
    """Build the toy loan-application dataset (header row + 15 samples).

    Returns a 2-D numpy array of strings; row 0 holds the column names
    (age, has-job, owns-house, credit, label) and each remaining row is
    one sample whose last column is the 'yes'/'no' class label.
    """
    header = ['年龄', '有工作', '有自己的房子', '信贷情况', 'vqa']
    samples = [
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no'],
    ]
    # np.array coerces the mixed int/str rows to a common string dtype,
    # exactly as the original single-literal construction did.
    return np.array([header] + samples)
 
def calcShannonEnt(dataSet, axis=-1):
    """Shannon entropy of column `axis` of a 2-D numpy array.

    H = -sum(p_v * log2(p_v)) over the distinct values v occurring in
    that column; by default the last column (the class label) is used.
    """
    total = len(dataSet)
    counts = Counter(dataSet[:, axis])
    # Sum the per-value contributions; negate once at the end instead of
    # subtracting inside the loop.
    return -sum((c / total) * log(c / total, 2) for c in counts.values())

def subDataSet(dataSet, axis, value):
    """Return the rows of `dataSet` whose column `axis` equals `value`.

    dataSet is a 2-D numpy array; the result keeps every column and
    preserves row order.
    """
    # Boolean-mask row selection replaces the original's np.where tuple
    # indexing plus a trailing [0] unwrap (which also left an unused
    # `numEntires` local); the returned array is identical.
    return dataSet[dataSet[:, axis] == value]
 
def entropyGain(dataSet, axis=1, baseAxis=-1):
    """Information gain of splitting `dataSet` on feature column `axis`.

    gain = H(baseAxis column) - sum_v p(v) * H(baseAxis column | axis == v),
    where v ranges over the distinct values in column `axis`.
    """
    total = len(dataSet)
    baseEntropy = calcShannonEnt(dataSet, baseAxis)
    conditional = 0.0
    # Counter.items() yields each split value with its frequency, so the
    # weight p(v) comes straight from the count.
    for value, count in Counter(dataSet[:, axis]).items():
        weight = count / total
        subset = subDataSet(dataSet=dataSet, axis=axis, value=value)
        conditional += weight * calcShannonEnt(dataSet=subset, axis=baseAxis)
    return baseEntropy - conditional
 
 
# Strip the header row, then report the base entropy of the label column
# and the information gain of each of the four feature columns.
data = createDataSet()[1:, :]
Entropy = calcShannonEnt(data)
print('Entropy is {Entropy:0.6f}'.format(Entropy=Entropy))

for feature in range(4):
    EntropyGain = entropyGain(data, axis=feature, baseAxis=4)
    print('EntropyGain is {EntropyGain}'.format(EntropyGain=EntropyGain))

Entropy is 0.970951
EntropyGain is 0.0830074998558
EntropyGain is 0.323650198152
EntropyGain is 0.419973094022
EntropyGain is 0.362989562537

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值