pytorch中的kl divergence计算问题

偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中计算结果不同,平时没有注意到,记录下

一篇关于KL散度、JS散度以及交叉熵对比的文章

kl divergence 介绍

KL散度( Kullback–Leibler divergence),又称相对熵,是描述两个概率分布 P 和 Q 差异的一种方法。计算公式:
D ( P ∣ ∣ Q ) = ∑ p i   l o g 2 p i q i D(P||Q)=\sum{p_i~log_2\frac{p_i}{q_i}} D(PQ)=pi log2qipi
可以发现,P 和 Q 中元素的个数不用相等,只需要两个分布中的离散元素一致。
举个简单例子:
两个离散分布分布分别为 P 和 Q
P 的分布为:{1,1,2,2,3}
Q 的分布为:{1,1,1,1,1,2,3,3,3,3}
我们发现,虽然两个分布中元素个数不相同,P 的元素个数为 5,Q 的元素个数为 10。但里面的元素都有 “1”,“2”,“3” 这三个元素。
当 x = 1时,在 P 分布中,“1” 这个元素的个数为 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 这个元素的个数为 5,故 Q(x = 1) = 5/10 = 0.5
同理,
当 x = 2 时,P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1
当 x = 3 时,P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4
把上述概率带入公式:
D ( P ∣ ∣ Q ) = 0.4   l o g 2 0.4 0.5 + 0.4   l o g 2 0.4 0.1 + 0.2   l o g 2 0.2 0.4 = 0.47 D(P||Q)=0.4~log_2\frac{0.4}{0.5}+0.4~log_2\frac{0.4}{0.1}+0.2~log_2\frac{0.2}{0.4}=0.47 D(PQ)=0.4 log20.50.4+0.4 log20.10.4+0.2 log20.40.2=0.47
至此,就计算完成了两个离散变量分布的KL散度。

pytorch 中的 kl_div 函数

pytorch中有用于计算kl散度的函数 kl_div

torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')

在这里插入图片描述

计算 D (p||q)

1、不用这个函数的计算结果为:
在这里插入图片描述
与手算结果相同
2、使用函数:(这是计算正确的,结果有差异是因为pytorch这个函数中默认的是以e为底)
在这里插入图片描述
注意:
1、函数中的 p q 位置相反(也就是想要计算D(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log
2、reduction 是选择对各部分结果做什么操作,默认为取平均数,这里选择求和

好别扭的用法,不知道为啥官方把它设计成这样

  • 21
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
KL-DivergenceKL散度)是一种用来衡量两个概率分布之间差异的指标。它始终大于等于0,当且仅当两个分布完全相同时,KL散度等于0。KL散度具有非对称性,即DKL(P||Q)不等于DKL(Q||P),并且不满足三角不等式的形式,因此KL散度不是用来衡量距离的指标。 KL散度的公式可以表示为DKL(P||Q) = ΣP(x) * log(P(x) / Q(x)),其P和Q分别是两个概率分布,x表示分布的某个事件。这个公式可以用来计算P相对于Q的信息损失,或者可以理解为在使用Q来近似表示P时的额外损失。 在机器学习KL散度经常被用于衡量两个概率分布之间的差异,例如在概率生成模型和信息检索。在PyTorch,可以使用F.kl_div()函数来计算KL散度。这个函数的原型为F.kl_div(input, target, size_average=None, reduce=None, reduction='mean'),其input和target分别是输入和目标张量。 总结起来,KL散度是一种用来衡量两个概率分布之间差异的指标,它不是用来衡量距离的,并且具有非对称性。在机器学习KL散度常被用于衡量模型输出与真实分布之间的差异。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [浅谈KL散度](https://blog.csdn.net/weixin_33774615/article/details/85768162)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* *3* [KL Divergence ——衡量两个概率分布之间的差异](https://blog.csdn.net/weixin_42521185/article/details/124364552)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Systemd

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值