1. 基本概念
KL散度(Kullback-Leibler Divergence) 是衡量两个概率分布
P
P
P(目标分布)和
Q
Q
Q(预测分布)差异的非对称指标,定义为:
D
K
L
(
P
∥
Q
)
=
∑
P
(
x
)
⋅
(
log
P
(
x
)
−
log
Q
(
x
)
)
D_{KL}(P \parallel Q) = \sum P(x) \cdot \left( \log P(x) - \log Q(x) \right)
DKL(P∥Q)=∑P(x)⋅(logP(x)−logQ(x))
在深度学习中,PyTorch 的 nn.KLDivLoss
用于计算这一散度,但需注意输入形式。
2. 输入与目标要求
• 输入(Input):
必须是 对数概率(log-probabilities),即模型输出需先经过 log_softmax
处理,得到
log
Q
\log Q
logQ。
# 示例:模型输出 logits 后处理
logits = model(x)
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
• 目标(Target):
必须是 概率分布,且形状与输入一致。例如:
• 分类任务中,若使用软标签(Soft Targets),目标可能是平滑后的概率(如知识蒸馏中的教师输出)。
• 若使用硬标签(One-Hot 编码),需手动转换为概率(如 [0, 1, 0]
)。
3. 参数 reduction='batchmean'
• 功能:
对损失的计算方式进行归一化:
• 'batchmean'
:所有样本的 KL 散度求和后,除以 Batch Size(即平均每个样本的损失)。
• 对比其他模式:
◦ 'mean'
:总损失除以 Batch Size × 类别数
,通常错误。
◦ 'sum'
:直接求和,不归一化。
◦ 'none'
:保留每个样本的损失。
• 数学公式:
Loss
=
1
N
∑
i
=
1
N
∑
c
=
1
C
P
i
(
c
)
⋅
(
log
P
i
(
c
)
−
log
Q
i
(
c
)
)
\text{Loss} = \frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C P_i(c) \cdot \left( \log P_i(c) - \log Q_i(c) \right)
Loss=N1∑i=1N∑c=1CPi(c)⋅(logPi(c)−logQi(c))
其中
N
N
N 为 Batch Size,
C
C
C 为类别数。
4. 使用场景
• 知识蒸馏(Knowledge Distillation):
教师模型生成软目标(Soft Targets),学生模型通过 KLDivLoss
匹配教师输出的分布。
# 示例:知识蒸馏损失计算
teacher_probs = torch.softmax(teacher_logits / temperature, dim=-1)
student_log_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
loss = nn.KLDivLoss(reduction='batchmean')(student_log_probs, teacher_probs)
• 生成模型:
衡量生成数据分布与真实数据分布的差异。
• 标签平滑(Label Smoothing):
目标分布为平滑后的概率(如 [0.1, 0.8, 0.1]
),避免模型过度自信。
5. 与交叉熵损失的关系
• 交叉熵损失(CrossEntropyLoss):
CrossEntropy
=
−
∑
P
(
x
)
⋅
log
Q
(
x
)
\text{CrossEntropy} = -\sum P(x) \cdot \log Q(x)
CrossEntropy=−∑P(x)⋅logQ(x)
其中
P
P
P 是 One-Hot 编码的真实标签,
Q
Q
Q 是预测概率。
• KL散度与交叉熵的关系:
D
K
L
(
P
∥
Q
)
=
CrossEntropy
(
P
,
Q
)
−
Entropy
(
P
)
D_{KL}(P \parallel Q) = \text{CrossEntropy}(P, Q) - \text{Entropy}(P)
DKL(P∥Q)=CrossEntropy(P,Q)−Entropy(P)
当
P
P
P 是固定分布(如分类任务中的真实标签),最小化 KL 散度等价于最小化交叉熵。
• 输入差异:
• CrossEntropyLoss
直接接受 未归一化的 Logits,内部自动进行 log_softmax
。
• KLDivLoss
需手动对输入进行 log_softmax
,且目标为概率。
6. 代码示例
import torch
import torch.nn as nn
# 模型输出(Logits)
logits = torch.randn(3, 5) # Batch Size=3, 类别数=5
# 目标分布(概率形式,每行和为1)
target_probs = torch.tensor([
[0.1, 0.2, 0.5, 0.1, 0.1],
[0.4, 0.3, 0.2, 0.1, 0.0],
[0.0, 0.0, 0.1, 0.2, 0.7]
])
# 处理输入:应用 log_softmax
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
# 计算 KL 散度损失
criterion = nn.KLDivLoss(reduction='batchmean')
loss = criterion(log_probs, target_probs)
print(loss) # 输出标量损失值
7. 常见问题
-
损失为负数?
• 检查输入是否经过log_softmax
。若输入是原始 Logits,结果可能非法。
• 确保目标概率和为 1。 -
与 CrossEntropyLoss 的互换性:
• 若目标为 One-Hot 编码,需转换为概率矩阵(如[0, 1, 0] → [0.0, 1.0, 0.0]
),此时两者等价。 -
梯度爆炸/消失:
• 确保输入值范围合理,避免log_softmax
后出现极端值(如 log 0 \log 0 log0)。
总结
nn.KLDivLoss(reduction='batchmean')
是衡量两个概率分布差异的核心工具,适用于需要明确匹配分布形状的任务(如知识蒸馏)。使用时需严格处理输入(log_softmax
)和目标(概率分布)。