KL散度 pytorch实现

机器学习 同时被 2 个专栏收录
8 篇文章 0 订阅
2 篇文章 0 订阅

KL散度 KL Divergence

D K L D_{KL} DKL 是衡量两个概率分布之间的差异程度。

考虑两个概率分布 P P P, Q Q Q(譬如前者为模型输出data对应的分布,后者为期望的分布),则KL散度的定义如下:
D K L = ∑ x P ( x ) l o g P ( x ) Q ( x ) D_{KL} = \sum_xP(x)log\frac{P(x)}{Q(x)} DKL=xP(x)logQ(x)P(x)

D K L = ∫ x P ( x ) l o g P ( x ) Q ( x ) D_{KL} = \int_xP(x)log\frac{P(x)}{Q(x)} DKL=xP(x)logQ(x)P(x)

具体知识参考https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

pytorch 实现

torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction=‘mean’, log_target=False)

The Kullback-Leibler divergence Loss

See KLDivLoss for details.

  • Parameters

    input – Tensor of arbitrary shape

    target – Tensor of the same shape as input

    size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True

    reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True

    reduction (string*,* optional) – Specifies the reduction to apply to the output: 'none' | 'batchmean' | 'sum' | 'mean'. 'none': no reduction will be applied 'batchmean': the sum of the output will be divided by the batchsize 'sum': the output will be summed 'mean': the output will be divided by the number of elements in the output Default: 'mean'

    log_target (bool) – A flag indicating whether target is passed in the log space. It is recommended to pass certain distributions (like softmax) in the log space to avoid numerical issues caused by explicit log. Default: False

input与target是shape相同的tensor, 往往是 number * feature的大小,即从number个样本 计算出feature服从的emperical distribution。

size_average 和 reduce参数已经启用

输出的shape与input相同

需要调整的是reduction参数,常用的是mean和sum

  • 0
    点赞
  • 0
    评论
  • 0
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

©️2021 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值