nn.KLDivLoss

1. 基本概念

KL散度(Kullback-Leibler Divergence) 是衡量两个概率分布 P P P(目标分布)和 Q Q Q(预测分布)差异的非对称指标,定义为:
D K L ( P ∥ Q ) = ∑ P ( x ) ⋅ ( log ⁡ P ( x ) − log ⁡ Q ( x ) ) D_{KL}(P \parallel Q) = \sum P(x) \cdot \left( \log P(x) - \log Q(x) \right) DKL(PQ)=P(x)(logP(x)logQ(x))
在深度学习中,PyTorch 的 nn.KLDivLoss 用于计算这一散度,但需注意输入形式。


2. 输入与目标要求

输入(Input)
必须是 对数概率(log-probabilities),即模型输出需先经过 log_softmax 处理,得到 log ⁡ Q \log Q logQ

# 示例:模型输出 logits 后处理
logits = model(x)
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

目标(Target)
必须是 概率分布,且形状与输入一致。例如:
• 分类任务中,若使用软标签(Soft Targets),目标可能是平滑后的概率(如知识蒸馏中的教师输出)。
• 若使用硬标签(One-Hot 编码),需手动转换为概率(如 [0, 1, 0])。


3. 参数 reduction='batchmean'

功能
对损失的计算方式进行归一化:
'batchmean':所有样本的 KL 散度求和后,除以 Batch Size(即平均每个样本的损失)。
对比其他模式
'mean':总损失除以 Batch Size × 类别数,通常错误。
'sum':直接求和,不归一化。
'none':保留每个样本的损失。

数学公式
Loss = 1 N ∑ i = 1 N ∑ c = 1 C P i ( c ) ⋅ ( log ⁡ P i ( c ) − log ⁡ Q i ( c ) ) \text{Loss} = \frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C P_i(c) \cdot \left( \log P_i(c) - \log Q_i(c) \right) Loss=N1i=1Nc=1CPi(c)(logPi(c)logQi(c))
其中 N N N 为 Batch Size, C C C 为类别数。


4. 使用场景

知识蒸馏(Knowledge Distillation)
教师模型生成软目标(Soft Targets),学生模型通过 KLDivLoss 匹配教师输出的分布。

# 示例:知识蒸馏损失计算
teacher_probs = torch.softmax(teacher_logits / temperature, dim=-1)
student_log_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
loss = nn.KLDivLoss(reduction='batchmean')(student_log_probs, teacher_probs)

生成模型
衡量生成数据分布与真实数据分布的差异。

标签平滑(Label Smoothing)
目标分布为平滑后的概率(如 [0.1, 0.8, 0.1]),避免模型过度自信。


5. 与交叉熵损失的关系

交叉熵损失(CrossEntropyLoss)
CrossEntropy = − ∑ P ( x ) ⋅ log ⁡ Q ( x ) \text{CrossEntropy} = -\sum P(x) \cdot \log Q(x) CrossEntropy=P(x)logQ(x)
其中 P P P 是 One-Hot 编码的真实标签, Q Q Q 是预测概率。

KL散度与交叉熵的关系
D K L ( P ∥ Q ) = CrossEntropy ( P , Q ) − Entropy ( P ) D_{KL}(P \parallel Q) = \text{CrossEntropy}(P, Q) - \text{Entropy}(P) DKL(PQ)=CrossEntropy(P,Q)Entropy(P)
P P P 是固定分布(如分类任务中的真实标签),最小化 KL 散度等价于最小化交叉熵。

输入差异
CrossEntropyLoss 直接接受 未归一化的 Logits,内部自动进行 log_softmax
KLDivLoss 需手动对输入进行 log_softmax,且目标为概率。


6. 代码示例
import torch
import torch.nn as nn

# 模型输出(Logits)
logits = torch.randn(3, 5)  # Batch Size=3, 类别数=5

# 目标分布(概率形式,每行和为1)
target_probs = torch.tensor([
    [0.1, 0.2, 0.5, 0.1, 0.1],
    [0.4, 0.3, 0.2, 0.1, 0.0],
    [0.0, 0.0, 0.1, 0.2, 0.7]
])

# 处理输入:应用 log_softmax
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

# 计算 KL 散度损失
criterion = nn.KLDivLoss(reduction='batchmean')
loss = criterion(log_probs, target_probs)

print(loss)  # 输出标量损失值

7. 常见问题
  1. 损失为负数?
    • 检查输入是否经过 log_softmax。若输入是原始 Logits,结果可能非法。
    • 确保目标概率和为 1。

  2. 与 CrossEntropyLoss 的互换性
    • 若目标为 One-Hot 编码,需转换为概率矩阵(如 [0, 1, 0] → [0.0, 1.0, 0.0]),此时两者等价。

  3. 梯度爆炸/消失
    • 确保输入值范围合理,避免 log_softmax 后出现极端值(如 log ⁡ 0 \log 0 log0)。


总结

nn.KLDivLoss(reduction='batchmean') 是衡量两个概率分布差异的核心工具,适用于需要明确匹配分布形状的任务(如知识蒸馏)。使用时需严格处理输入(log_softmax)和目标(概率分布)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值