pytorch中的KL散度详解torch.nn.functional.kl_div

KL散度是衡量两个概率分布相似度的指标,在PyTorch中通过F.kl_div函数实现。该函数接受对数概率和目标概率分布作为输入,支持sum,mean,none三种_reduction_模式来处理输出。默认行为是计算平均KL散度。示例展示了如何使用softmax转换logits并计算批次的平均KL散度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

KL散度的公式

F.kl_div 是 PyTorch 中的一个函数,用于计算两个概率分布之间的 Kullback-Leibler (KL) 散度。

KL 散度是一种非对称的测量,用于衡量两个概率分布的相似度。如果两个分布完全相同,KL 散度为零;否则,KL 散度为一个正数。

在 PyTorch 中,F.kl_div 的输入是两个张量,其中第一个张量的每个元素应该是第二个张量对应元素的对数概率。因此,F.kl_div 的输入应该满足下面的条件:

  1. 第一个输入张量 input:这个张量的元素应该是第二个张量对应元素的对数概率,即 log P(x)

  2. 第二个输入张量 target:这个张量的元素是目标概率分布,即 Q(x)

F.kl_div 的计算公式是:

KL(P || Q) = Σ P(x) * log(P(x) / Q(x))

在 PyTorch 中,这个公式稍有改动,改为:

KL(P || Q) = Σ P(x) * (log(P(x)) - Q(x))

这是因为第一个输入张量已经是对数概率了。

torch.nn.functional.kl_div中的参数

PyTorch 的 KL 散度函数 F.kl_div 的行为取决于两个参数:size_averagereductionreduction 参数有三个可能的值:'none', 'sum', 和 'mean',而 size_average 只有在 reduction'mean' 时才有意义。

  • 如果 reduction='none',那么 F.kl_div 将返回一个和输入张量同样大小的新张量,其中每个元素表示相应位置的 KL 散度。

  • 如果 reduction='sum',那么 F.kl_div 将返回一个标量,表示所有元素的 KL 散度的总和。

  • 如果 reduction='mean',那么 F.kl_div 的行为将取决于 size_average 参数:

    • 如果 size_average=True,那么 F.kl_div 将返回一个标量,表示所有元素的 KL 散度的平均值。
    • 如果 size_average=False,那么 F.kl_div 将返回一个标量,表示所有元素的 KL 散度的总和。

这个函数的默认行为是 reduction='mean'size_average=True,也就是返回所有元素的 KL 散度的平均值。

如果你想要得到所有样本 KL 散度的平均值,你需要设置 reduction='mean'size_average=True

使用示例

假设我们有一个 2D 张量 logits,其维度是 [batch_size, num_classes],每一行代表一个样本,每一列代表一个类别。这样,logits[i, j] 就代表了第 i 个样本属于第 j 类的对数概率。我们通常将这个张量通过 softmax 函数转换成概率分布,即每一行的所有元素都是非负的,并且和为1。

例如,假设我们有如下的 logits:

logits = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])

这意味着我们有两个样本(批次大小为2),每个样本有三个类别的概率分布。我们可以使用 softmax 函数将 logits 转换成概率分布:

probs = torch.nn.functional.softmax(logits, dim=1)
print(probs)

输出:

tensor([[0.2271, 0.4116, 0.3613],
        [0.2900, 0.3200, 0.3900]])

这里,probs[i, j] 代表第 i 个样本属于第 j 类的概率。你可以看到每一行的所有元素都是非负的,并且和为1,因此每一行都是一个概率分布。

然后,我们可以用 F.kl_div 计算这个概率分布和一个目标概率分布之间的 KL 散度。假设目标概率分布是:

target_probs = torch.tensor([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]])

那么,我们可以计算 KL 散度如下:

loss = F.kl_div(torch.log(probs), target_probs, reduction='batchmean')
print(loss)

这里 torch.log(probs) 计算的是每个元素的对数概率,target_probs 是目标概率分布,reduction='batchmean' 表示计算批次的平均 KL 散度。

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值