【KL散度】stats.entropy、special.rel_entr、special.kl_div、F.kl_div与nn.KLDivLoss用法解析

  偶然学习KL散度,突然发现python里面KL散度的实现有很多种耶,一时就地懵圈,各处查阅资料,终于理解了,主要从代码实现和公式的角度,整理记录一下神奇的stats.entropy、special.rel_entr、special.kl_div、F.kl_div与nn.KLDivLoss吧。
  


1. KL散度

  KL散度(Kullback-Leibler divergence)用于度量两个概率分布的相似度,可作为经典损失函数,设有 P {P} P 为真实分布, Q {Q} Q 为近似分布,若为离散随机变量,则公式表示为:

D K L ( P ∣ ∣ Q ) = ∑ i P ( i ) I n ( P ( i ) Q ( i ) ) {D_{KL}}\left( {P||Q} \right) = \sum\limits_i {P\left( i \right)} {\mathop{\rm In}\nolimits} \left( {\frac{{P\left( i \right)}}{{Q\left( i \right)}}} \right) DKL(P∣∣Q)=iP(i)In(Q(i)P(i))   若为连续随机变量,则公式表示为:
D K L ( P ∣ ∣ Q ) = ∫ − ∞ ∞ p ( x ) I n ( P ( x ) Q ( x ) ) d x {D_{KL}}\left( {P||Q} \right) = \int_{ - \infty }^\infty {p\left( x \right)} {\mathop{\rm In}\nolimits} \left( {\frac{{P\left( x \right)}}{{Q\left( x \right)}}} \right)dx DKL(P∣∣Q)=p(x)In(Q(x)P(x))dx   KL散度要求输入的概率分布之和为1,因此在实际计算时,需要确保概率分布满足这个条件。另外,KL散度并不是一个对称函数,即 D K L ( P ∣ ∣ Q ) {{D_{KL}}\left( {P||Q} \right)} DKL(P∣∣Q) 不等于 D K L ( Q ∣ ∣ K ) {{D_{KL}}\left( {Q||K} \right)} DKL(Q∣∣K)


2. stats.entropy

  stats.entropy(官方文档)可以计算香农熵也可以计算相对熵,即KL散度,其包括4个参数,分布pk和qk,对数底base(默认为e)以及计算维度axis(默认为0):

scipy.stats.entropy(pk, qk=None, base=None, axis=0)

  其公式表示为:
e n t r o p y ( x , y ) = ( ∑ x log ⁡ ( x / y ) ) / log ⁡ ( b a s e ) {\mathrm{entropy}(x, y) =(\sum x \log(x / y))/ \log(base)} entropy(x,y)=(xlog(x/y))/log(base)   python中实现为:

import numpy as np
import scipy.stats

p = [0.1, 0.2, 0.3, 0.6]  
q = [0.2, 0.2, 0.2, 0.2] 

out0 = scipy.stats.entropy(p)
out1 = scipy.stats.entropy(p, q)  
out2 = scipy.stats.entropy(q, p)
out3 = scipy.stats.entropy(p, q, base=2)
print("p的香农熵:", out0)
print("p和q的相对熵:", out1)
print("不对称性验证:", out2)
print("base参数:", out3)
print("base参数验证:", out1 / np.log(2))

# 归一化---------------------------------------------
pp = p / np.sum(p)
qq = q / np.sum(q)
# ---------------------------------------------------
zz = []
for i in range(4):
    temp = -pp[i] * np.log(pp[i])
    zz.append(temp)
print("手动计算香农熵:", np.sum(zz))

xx = []
for i in range(4):
    temp = pp[i] * np.log(pp[i] / qq[i])
    xx.append(temp)
print("手动计算相对熵:", np.sum(xx))

  值得注意的是,手动计算的时候,需要先将p和q规范化,使得其元素和为1,而stats.entropy函数自动实现了这一步。由结果可知,stats.entropy与手动计算的输出是一致的,KL散度不具有对称性:

p的香农熵: 1.1988493129136213
p和q的相对熵: 0.18744504820626923
不对称性验证: 0.20273255405408233
base参数: 0.27042604148637733
base参数验证: 0.27042604148637733
手动计算香农熵: 1.1988493129136213
手动计算相对熵: 0.18744504820626923

3. special.rel_entr

  special.rel_entr(官方文档)也可以计算KL散度,其公式表示为:
r e l _ e n t r ( x , y ) = { x log ⁡ ( x / y ) x > 0 , y > 0 0 x = 0 , y ≥ 0 ∞ otherwise {\mathrm{rel\_entr}(x, y) = \begin{cases} x \log(x / y) & x > 0, y > 0 \\ 0 & x = 0, y \ge 0 \\ \infty & \text{otherwise} \end{cases}} rel_entr(x,y)= xlog(x/y)0x>0,y>0x=0,y0otherwise   python中实现为:

import numpy as np
import scipy.special

p = [0.1, 0.2, 0.3, 0.6]  
q = [0.2, 0.2, 0.2, 0.2] 

out1 = scipy.stats.entropy(p, q)  
print("stats.entropy:", out1)

# 归一化---------------------------------------------
pp = p / np.sum(p)
qq = q / np.sum(q)
# ---------------------------------------------------

out2 = scipy.special.rel_entr(pp, qq)
print("special.rel_entr:", np.sum(out2))
print("元素输出:", out2)

  值得注意的是special.rel_entr的输入是规范化之后的分布,而非原始分布,special.rel_entr的直接输出是元素输出,求和之后,special.rel_entr与stats.entropy计算结果一致:

stats.entropy: 0.18744504820626923
special.rel_entr: 0.18744504820626923
元素输出: [-9.15510241e-02 -6.75775180e-02 -5.55111512e-17  3.46573590e-01]

4. special.kl_div

  special.kl_div(官方文档)从名字看似乎是最正宗的KL散度,但其公式表示为:
k l _ d i v ( x , y ) = { x log ⁡ ( x / y ) − x + y x > 0 , y > 0 y x = 0 , y ≥ 0 ∞ otherwise {\mathrm{kl\_div}(x, y) = \begin{cases} x \log(x / y) - x + y & x > 0, y > 0 \\ y & x = 0, y \ge 0 \\ \infty & \text{otherwise} \end{cases}} kl_div(x,y)= xlog(x/y)x+yyx>0,y>0x=0,y0otherwise   python中实现为:

import numpy as np
import scipy.special
import scipy.stats

p = [0.1, 0.2, 0.3, 0.6]  
q = [0.2, 0.2, 0.2, 0.2] 

# 归一化---------------------------------------------
pp = p / np.sum(p)
qq = q / np.sum(q)
# ---------------------------------------------------

out1 = scipy.stats.entropy(p, q)  
print("stats.entropy:", out1)

out2 = scipy.special.rel_entr(pp, qq)
print("special.rel_entr:", np.sum(out2))
print("special.rel_entr元素:", out2)

out3 = scipy.special.kl_div(pp, qq)
print("special.kl_div:", np.sum(out3))
print("special.kl_div元素:", out3)

xx = []
for i in range(4):
    temp = (pp[i] * np.log(pp[i] / qq[i])) - pp[i] + qq[i]
    xx.append(temp)
print("手动special.kl_div计算:", np.sum(xx))
print("手动special.kl_div元素:", xx)

  与special.rel_entr相同,special.kl_div需要输入规范化之后的分布,也是元素输出,求和之后,special.kl_div、special.rel_entr和stats.entropy计算结果一致,不同的是,special.kl_div的元素与special.rel_entr不一样,因为每一项多了 − x + y {-x+y} x+y ,但因输入分布经过了规范化,故求和后值相同:

stats.entropy: 0.18744504820626923
special.rel_entr: 0.18744504820626923
special.rel_entr元素: [-9.15510241e-02 -6.75775180e-02 -5.55111512e-17  3.46573590e-01]
special.kl_div: 0.1874450482062694
special.kl_div元素: [0.07511564 0.01575582 0.         0.09657359]
手动special.kl_div计算: 0.1874450482062694
手动special.kl_div元素: [0.07511564261099085, 0.015755815315305954, 0.0, 0.09657359027997259]

5. F.kl_div与nn.KLDivLoss

  F.kl_div(官方文档)和nn.KLDivLoss(官方文档)是torch中实现的KL散度损失函数,设 y pred {y_{\text{pred}}} ypred 为模型预测输出, y true {y_{\text{true}}} ytrue 为真实分布,公式表达为:
L ( y pred ,   y true ) = y true ⋅ log ⁡ y true y pred = y true ⋅ ( log ⁡ y true − log ⁡ y pred ) {L(y_{\text{pred}},\ y_{\text{true}}) = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}} = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}})} L(ypred, ytrue)=ytruelogypredytrue=ytrue(logytruelogypred)   但他们实际实现的时候与上述公式有些许的不同,主要由参数 log_target 控制。

  当 log_target=False 时,公式为:
L ( y pred ,   y true ) = y true ⋅ ( log ⁡ y true − y pred ) {L(y_{\text{pred}},\ y_{\text{true}}) = y_{\text{true}} \cdot (\log y_{\text{true}} - y_{\text{pred}})} L(ypred, ytrue)=ytrue(logytrueypred)   当 log_target=True 时,公式为:
L ( y pred ,   y true ) = e y true ⋅ ( y true − y pred ) {L(y_{\text{pred}},\ y_{\text{true}}) = e^{y_{\text{true}}} \cdot (y_{\text{true}} - y_{\text{pred}})} L(ypred, ytrue)=eytrue(ytrueypred)   值得注意的是,F.kl_div与nn.KLDivLoss输入的第一个分布为对数概率分布( y pred y_{\text{pred}} ypred ),第二个分布为概率分布( y true y_{\text{true}} ytrue ),由于 y pred y_{\text{pred}} ypred 已经是对数了,所以当log_target=False时, y pred y_{\text{pred}} ypred 那没有取对数,而 y true y_{\text{true}} ytrue 那取了对数。

  python中实现为:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.tensor([0.1, 0.2, 0.3, 0.6])
y = torch.tensor([0.2, 0.2, 0.2, 0.2])

logp_x = F.log_softmax(x, dim=-1)  # torch.log(F.softmax(x, dim=-1))
p_y = F.softmax(y, dim=-1)  # [0.25, 0.25, 0.25, 0.25] 

kl_sum = F.kl_div(logp_x, p_y, reduction='sum')
kl_mean = F.kl_div(logp_x, p_y, reduction='batchmean')
print(kl_sum, kl_mean)

kl_sum_log_target = F.kl_div(logp_x, p_y, reduction='sum', log_target=True)
kl_mean_log_target = F.kl_div(logp_x, p_y, reduction='batchmean', log_target=True)
print(kl_sum_log_target, kl_mean_log_target)

kl_loss_sum = nn.KLDivLoss(reduction="sum")
output1 = kl_loss_sum(logp_x, p_y)
kl_loss_mean = nn.KLDivLoss(reduction="batchmean")
output2 = kl_loss_mean(logp_x, p_y)  # logp_x:pred, p_y:target/true
print(output1, output2)

xx = []
for i in range(4):
    temp = p_y[i] * (p_y[i].log() - logp_x[i])
    xx.append(temp)
print("log_target=False, 手动计算KL_loss:", sum(xx), (xx[0]+xx[1]+xx[2]+xx[3])/4)

yy = []
for i in range(4):
    temp = p_y[i].exp() * (p_y[i] - logp_x[i])
    yy.append(temp)
print("log_target=True, 手动计算KL_loss:", sum(yy), (yy[0]+yy[1]+yy[2]+yy[3])/4)

  验证了手动计算与F.kl_div与nn.KLDivLoss的输出一致:

tensor(0.0182) tensor(0.0045)
tensor(8.4976) tensor(2.1244)
tensor(0.0182) tensor(0.0045)
log_target=False, 手动计算KL_loss: tensor(0.0182) tensor(0.0045)
log_target=True, 手动计算KL_loss: tensor(8.4976) tensor(2.1244)
  • 24
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值