KL散度的公式
F.kl_div
是 PyTorch 中的一个函数,用于计算两个概率分布之间的 Kullback-Leibler (KL) 散度。
KL 散度是一种非对称的测量,用于衡量两个概率分布的相似度。如果两个分布完全相同,KL 散度为零;否则,KL 散度为一个正数。
在 PyTorch 中,F.kl_div
的输入是两个张量,其中第一个张量的每个元素应该是第二个张量对应元素的对数概率。因此,F.kl_div
的输入应该满足下面的条件:
-
第一个输入张量
input
:这个张量的元素应该是第二个张量对应元素的对数概率,即log P(x)
。 -
第二个输入张量
target
:这个张量的元素是目标概率分布,即Q(x)
。
F.kl_div
的计算公式是:
KL(P || Q) = Σ P(x) * log(P(x) / Q(x))
在 PyTorch 中,这个公式稍有改动,改为:
KL(P || Q) = Σ P(x) * (log(P(x)) - Q(x))
这是因为第一个输入张量已经是对数概率了。
torch.nn.functional.kl_div中的参数
PyTorch 的 KL 散度函数 F.kl_div
的行为取决于两个参数:size_average
和 reduction
。reduction
参数有三个可能的值:'none'
, 'sum'
, 和 'mean'
,而 size_average
只有在 reduction
为 'mean'
时才有意义。
-
如果
reduction='none'
,那么F.kl_div
将返回一个和输入张量同样大小的新张量,其中每个元素表示相应位置的 KL 散度。 -
如果
reduction='sum'
,那么F.kl_div
将返回一个标量,表示所有元素的 KL 散度的总和。 -
如果
reduction='mean'
,那么F.kl_div
的行为将取决于size_average
参数:- 如果
size_average=True
,那么F.kl_div
将返回一个标量,表示所有元素的 KL 散度的平均值。 - 如果
size_average=False
,那么F.kl_div
将返回一个标量,表示所有元素的 KL 散度的总和。
- 如果
这个函数的默认行为是 reduction='mean'
和 size_average=True
,也就是返回所有元素的 KL 散度的平均值。
如果你想要得到所有样本 KL 散度的平均值,你需要设置 reduction='mean'
和 size_average=True
。
使用示例
假设我们有一个 2D 张量 logits
,其维度是 [batch_size, num_classes]
,每一行代表一个样本,每一列代表一个类别。这样,logits[i, j]
就代表了第 i 个样本属于第 j 类的对数概率。我们通常将这个张量通过 softmax 函数转换成概率分布,即每一行的所有元素都是非负的,并且和为1。
例如,假设我们有如下的 logits:
logits = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])
这意味着我们有两个样本(批次大小为2),每个样本有三个类别的概率分布。我们可以使用 softmax 函数将 logits 转换成概率分布:
probs = torch.nn.functional.softmax(logits, dim=1)
print(probs)
输出:
tensor([[0.2271, 0.4116, 0.3613],
[0.2900, 0.3200, 0.3900]])
这里,probs[i, j]
代表第 i 个样本属于第 j 类的概率。你可以看到每一行的所有元素都是非负的,并且和为1,因此每一行都是一个概率分布。
然后,我们可以用 F.kl_div
计算这个概率分布和一个目标概率分布之间的 KL 散度。假设目标概率分布是:
target_probs = torch.tensor([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]])
那么,我们可以计算 KL 散度如下:
loss = F.kl_div(torch.log(probs), target_probs, reduction='batchmean')
print(loss)
这里 torch.log(probs)
计算的是每个元素的对数概率,target_probs
是目标概率分布,reduction='batchmean'
表示计算批次的平均 KL 散度。