Mutual Information
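
The code below estimates three quantities for a pair of discrete samples: the Shannon entropy, the joint entropy, and the mutual information, each implemented once with NumPy and once with PyTorch. In bits, the standard definitions are

H(X) = -\sum_x p(x)\log_2 p(x)
H(X,Y) = -\sum_{x,y} p(x,y)\log_2 p(x,y)
I(X;Y) = H(X) + H(Y) - H(X,Y)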

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import numexpr as en
sns.set()

X = np.array([1,0,1,1,0])
Y = np.array([1,1,1,0,0])

def checkType(X):
    # Convert array-like input to a float tensor; pass tensors through unchanged.
    if isinstance(X, (np.ndarray, list, tuple)):
        return torch.FloatTensor(X)
    elif isinstance(X, torch.Tensor):
        return X.float()
    else:
        raise TypeError("Expected ndarray, list, tuple or torch.Tensor")

def NumpyProb(X, symbol, x):
    # Empirical probability that the comparison "X <symbol> x" holds, e.g. P(X == x).
    n = X.size
    Lambda = "X{}x".format(symbol)      # e.g. "X==x"
    mask = en.evaluate(Lambda)          # numexpr resolves X and x from the local scope
    return mask.dot(np.ones(n)) / n     # fraction of entries where the comparison is True
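
NumpyProb builds the comparison string (e.g. "X==x") and lets numexpr evaluate it elementwise, so with the X defined above:

NumpyProb(X, '==', 1)   # 0.6: three of the five entries equal 1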


def NumpyEntropy(X):
    # Shannon entropy H(X) = -sum_x p(x) log2 p(x) over the observed values.
    NE = 0
    for x in np.unique(X):
        PX = NumpyProb(X, '==', x)
        NE += (- PX * np.log2(PX))
    return NE
NumpyEntropy(X)
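
For X = [1, 0, 1, 1, 0] the empirical probabilities are p(1) = 0.6 and p(0) = 0.4, so

H(X) = -0.6\log_2 0.6 - 0.4\log_2 0.4 \approx 0.971 \text{ bits}

Y has the same value counts, so H(Y) is also about 0.971 bits.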


def TorchEntropy(X):
    # Same computation as NumpyEntropy, but on a torch tensor.
    NE = 0
    init_X = checkType(X)
    m = float(init_X.numel())
    for x in torch.unique(init_X):
        PX = (init_X == x).sum().float() / m
        NE += (- PX * torch.log2(PX))
    return NE
TorchEntropy(X)
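
As a quick sanity check (not part of the original snippet), the NumPy and Torch implementations should agree on the sample data:

assert np.isclose(NumpyEntropy(X), TorchEntropy(X).item())
assert np.isclose(NumpyEntropy(Y), TorchEntropy(Y).item())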


def NumpyJointEntropy(X, Y):
    # Joint entropy H(X,Y) = -sum_{x,y} p(x,y) log2 p(x,y).
    init_XY = np.vstack((X, Y)).T
    m, n = init_XY.shape
    enumerate_ = np.array([(x, y) for x in np.unique(X) for y in np.unique(Y)])
    # Empirical probability of each (x, y) pair: a row counts when both columns match.
    temp = np.array([((init_XY == e).dot(np.ones(n)) == n).dot(np.ones(m)) / m for e in enumerate_])
    temp = temp[temp > 0]    # drop unseen pairs, using the convention 0 * log2(0) = 0
    return (-temp * np.log2(temp)).sum()
NumpyJointEntropy(X,Y)
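
The observed pairs are (1,1) twice and (0,1), (1,0), (0,0) once each, so

H(X,Y) = -\tfrac{2}{5}\log_2\tfrac{2}{5} - 3\cdot\tfrac{1}{5}\log_2\tfrac{1}{5} \approx 1.922 \text{ bits}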


def TorchJointEntropy(X, Y):
    # Torch version of NumpyJointEntropy; candidate pairs come from the union of observed values.
    comb = lambda SET: np.array([[i, j] for i in SET[::-1] for j in SET])
    element = checkType(comb(np.union1d(np.unique(X), np.unique(Y))))
    XY = torch.cat([checkType(X)[:, None], checkType(Y)[:, None]], 1)
    m = XY.size(0)
    # A row matches a candidate pair when both of its columns match.
    PXY = torch.stack([(((XY == e).float()).mm(torch.ones((2, 1))) == 2).sum().float() / m for e in element])
    PXY = PXY[PXY > 0]       # 0 * log2(0) = 0 convention
    return (-PXY * torch.log2(PXY)).sum()
TorchJointEntropy(X,Y)


def MutualInfo(X, Y, keyelement="X<->Y", keyType='numpy'):
    # "X<->Y" returns the mutual information I(X;Y) = H(X) + H(Y) - H(X,Y);
    # "X<-Y" and "X->Y" return the conditional entropies H(X|Y) and H(Y|X).
    if keyType == 'numpy':
        HX, HY, HXY = NumpyEntropy(X), NumpyEntropy(Y), NumpyJointEntropy(X, Y)
    elif keyType == "torch":
        HX, HY, HXY = TorchEntropy(X), TorchEntropy(Y), TorchJointEntropy(X, Y)
    return {"X<-Y": HXY - HY, "X->Y": HXY - HX, "X<->Y": HX + HY - HXY}[keyelement]
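
Plugging in the values computed above: I(X;Y) = 0.971 + 0.971 - 1.922 ≈ 0.020 bits, so the two samples share almost no information, while H(X|Y) = H(Y|X) ≈ 1.922 - 0.971 ≈ 0.951 bits. The calls below return these values using the torch implementations.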


MutualInfo(X,Y,keyelement='X<->Y',keyType='torch')


MutualInfo(X,Y,keyelement='X->Y',keyType='torch')


MutualInfo(X,Y,keyelement='X<-Y',keyType='torch')