Pytorch复习笔记--KL散度的计算

1--KL散度

        KL 散度可用于衡量两个概率分布之间的相似性,两者越相似,其 KL 散度越小;

        KL 散度可用于多模态对比学习当中,用于比较两个模态之间的相似性;

2--计算公式

         P 表示真实概率分布,Q表示预测概率分布;

3--Pytorch实现

        Pytorch 通过 torch.nn.KLDivLoss() 计算 KL 散度:

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

if __name__ == "__main__":
    P = [0.4, 0.6] # 真实概率
    Q = [0.3, 0.7] # 预测概率

    # Manual
    KL1 = (0.4*np.log(0.4) - 0.4*np.log(0.3)) + (0.6*np.log(0.6) - 0.6*np.log(0.7)) # 0.226
    print("KL1: ", KL1)

    Cal_KL2 = nn.KLDivLoss()
    PP = F.softmax(torch.tensor(P), -1)
    QQ = F.log_softmax(torch.tensor(Q), -1)
    KL2 = Cal_KL2(QQ, PP)
    print("KL2: ", KL2)
    print("All Done !")

        torch.nn.KLDivLoss() 的第一个参数为预测概率分布 Q,第二个参数为真实概率分布 P;

        一般情况下,Q 和 P 都需要经过 softmax() 操作以保证概率和 1,Q还需要进行 log() 操作;

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值