from sklearn import metrics
import math
import numpy as np
from sklearn import metrics
from sklearn import metrics
def NMI(A,B): #此为计算互信息的函数,还可以计算标准互信息
#样本点数
total = len(A)
A_ids = set(A)
B_ids = set(B)
#互信息计算
MI = 0
eps = 1.4e-45
for idA in A_ids:
for idB in B_ids:
idAOccur = np.where(A==idA)
idBOccur = np.where(B==idB)
idABOccur = np.intersect1d(idAOccur,idBOccur)
px = 1.0*len(idAOccur[0])/total
py = 1.0*len(idBOccur[0])/total
pxy = 1.0*len(idABOccur)/total
MI = MI + pxy*math.log(pxy/(px*py)+eps,2)
# 标准化互信息
Hx = 0
for idA in A_ids:
idAOccurCount = 1.0*len(np.where(A==idA)[0])
Hx = Hx - (idAOccurCount/total)*math.log(idAOccurCount/total+eps,2)
Hy = 0
for idB in B_ids:
idBOccurCount = 1.0*len(np.where(B==idB)[0])
Hy = Hy - (idBOccurCount/total)*math.log(idBOccurCount/total+eps,2)
MIhat = 2.0*MI/(Hx+Hy)
return MI
if __name__ == '__main__':
dic = {1: {'w7w93t': 0.6901, 'w7w93y': 0.114, 'w7w93w': 0.0887, 'w7w93q': 0.0439, 'w7w93p': 0.0263, 'w7w93v': 0.0253, 'w7w93m': 0.0058, 'w7w93k': 0.0019, 'w7w93n': 0.0019, 'w7w932': 0.001, 'w7w93j': 0.001}, 2: {'w7w93t': 0.6684, 'w7w93y': 0.1265, 'w7w93w': 0.0961, 'w7w93q': 0.048, 'w7w93v': 0.0248, 'w7w93p': 0.0227, 'w7w93m': 0.0062, 'w7w93n': 0.0031, 'w7w93k': 0.001, 'w7w93j': 0.001, 'w7w93s': 0.001, 'w7w932': 0.0005, 'w7w93z': 0.0005}, 3: {'w7w93t': 0.6798, 'w7w93y': 0.1168, 'w7w93w': 0.0913, 'w7w93q': 0.049, 'w7w93v': 0.0266, 'w7w93p': 0.0232, 'w7w93m': 0.0068, 'w7w93n': 0.0024, 'w7w93k': 0.001, 'w7w93s': 0.001, 'w7w93j': 0.0007, 'w7w93z': 0.0007, 'w7w932': 0.0003, 'w7w931': 0.0003}, 4: {'w7w93t': 0.6875, 'w7w93y': 0.1077, 'w7w93w': 0.0952, 'w7w93q': 0.0445, 'w7w93v': 0.0296, 'w7w93p': 0.0236, 'w7w93m': 0.0062, 'w7w93n': 0.0022, 'w7w93s': 0.001, 'w7w93k': 0.0007, 'w7w93j': 0.0005, 'w7w93z': 0.0005, 'w7w93x': 0.0002, 'w7w932': 0.0002, 'w7w931': 0.0002}, 5: {'w7w93t': 0.6834, 'w7w93y': 0.1093, 'w7w93w': 0.0953, 'w7w93q': 0.0455, 'w7w93v': 0.029, 'w7w93p': 0.0248, 'w7w93m': 0.0062, 'w7w93n': 0.0028, 'w7w93j': 0.0008, 'w7w93s': 0.0008, 'w7w93z': 0.0006, 'w7w93k': 0.0006, 'w7w932': 0.0004, 'w7w931': 0.0002, 'w7w93x': 0.0002, 'w7w930': 0.0002}}
i = 1
while i < len(dic.keys()) + 1: # 把90天的数据分离出来成为字典
for k in dic[5].keys(): # 比较两个字典,没有的key的加上该key并取值为0
if k not in dic[i].keys():
dic[i][k] = 0
dic[i] = {j: dic[i][j] for j in sorted(dic[i].keys())} # 对dic[i]字典按照key值排序后输出
i += 1
print('90天的数据处理排序后是')
print(dic)
print('*****开始计算互信息**********')
c=1
result1=[]#result存放最终对比的NMI结果
result2=[]
while c < len(dic.keys()) + 1:
A=list(dic[c].values())
B=list(dic[5].values()) #B是最终的那个数据
C = np.array(A)
D = np.array(B)
result_NMI1 = NMI(C, D) #计算互信息
result_NMI2 = metrics.normalized_mutual_info_score(A, B)
result1.append(result_NMI1)
result2.append(result_NMI2)
c+=1
print(result1)
print(result2)
根据字典数据计算互信息
最新推荐文章于 2023-10-16 21:15:26 发布