Pytorch之KLDivLoss

理论基础

KL散度:衡量两个概率分布之间的相似性,其值越小,概率分布越接近。公式表达如下。

D K L ( P ∥ Q ) = ∑ i = 1 N [ p ( x i ) log ⁡ p ( x i ) − p ( x i ) log ⁡ q ( x i ) ] = ∑ i = 1 N [ p ( x i ) log ⁡ p ( x i ) log ⁡ q ( x i ) ] \begin{aligned} D_{K L}(P \| Q) & =\sum_{i=1}^{N}\left[p\left(x_{i}\right) \log p\left(x_{i}\right)-p\left(x_{i}\right) \log q\left(x_{i}\right)\right] \\ & = \sum_{i=1}^{N}\left[p\left(x_{i}\right) \frac{\log p\left(x_{i}\right)}{\log q\left(x_{i}\right)} \right] \end{aligned} DKL(PQ)=i=1N[p(xi)logp(xi)p(xi)logq(xi)]=i=1N[p(xi)logq(xi)logp(xi)]

注:对于两个概率分布 P P P Q Q Q P P P 为真实事件的概率分布, Q Q Q 为随机事件拟合出来的该事件的概率分布,即 D K L ( P ∥ Q ) D_{K L}(P \| Q) DKL(PQ) 表示使用 P P P 来拟合 Q Q Q, 或者说使用 Q Q Q 来指导 P P P

实现

import torch
import torch.nn as nn
import torch.nn.functional as F

# 预测值
input = torch.tensor([0.7, .1, .2], requires_grad=True)  # dim=0 每一行为一个样本

# 真实值
target = torch.tensor([.2, .5, .3])

# 计算KL散度
# 方式1
kl_loss = nn.KLDivLoss(reduction="batchmean")
output = kl_loss(F.log_softmax(input, dim=0), F.softmax(target, dim=0))
print(output)

# 方式2
print(F.kl_div(F.log_softmax(input, dim=0), F.softmax(target, dim=0), reduction="batchmean"))

# 方式3
my_kl_loss = F.softmax(target, dim=0) * (torch.log(F.softmax(target, dim=0)) - F.log_softmax(input, dim=0))
my_kl_loss = my_kl_loss.mean()
print(my_kl_loss)

# 方式4
my_kl_loss2 = F.softmax(target, dim=0) * (F.log_softmax(target, dim=0) - F.log_softmax(input, dim=0))
my_kl_loss2 = my_kl_loss2.mean()
print(my_kl_loss2)

# ----------------输出--------------------
# tensor(0.0239, grad_fn=<DivBackward0>)
# tensor(0.0239, grad_fn=<MeanBackward0>)
# tensor(0.0239, grad_fn=<MeanBackward0>)
# tensor(0.0239, grad_fn=<DivBackward0>)
# ----------------------------------------

几个要点

  1. KL散度的原理
  2. KL实现为什么要做log和softmax
  3. 上溢出和下溢出的情况
  4. 在pytorch的log函数中,默认是以 e e e 为底数的

参考:

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值