深度学习剖根问底:交叉熵和KL散度的区别

本文探讨了交叉熵与相对熵在机器学习中的应用。交叉熵作为一种损失函数,用于衡量真实分布与预测分布间的差异;相对熵(KL散度)则量化了两个概率分布之间的差距。文中详细解释了两者的数学定义及它们之间的联系,并指出优化交叉熵等效于最大化似然估计。

交叉熵可在神经网络(机器学习)中作为损失函数,p表示真实标记的分布,q则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量真实分布p与当前训练得到的概率分布q有多么大的差异。

相对熵(relative entropy)就是KL散度(Kullback–Leibler divergence),用于衡量两个概率分布之间的差异。

对于两个概率分布p(x)q(x) ,其相对熵的计算公式为:

\tt KL\it(p\parallel q)=-\int p(x)\ln q(x) dx -(-\int p(x)\ln p(x) dx)

注意:由于p(x) 和q(x) 在公式中的地位不是相等的,所以\tt KL \it(p\parallel q)\not\equiv \tt KL \it (q\parallel p)

相对熵的特点,是只有p(x)=q(x) 时,其值为0。若p(x) 和q(x) 略有差异,其值就会大于0。

相对熵公式的前半部分-\int p(x)\ln q(x)dx 就是交叉熵(cross entropy)。

p(x) 是数据的真实概率分布,q(x) 是由数据计算得到的概率分布。机器学习的目的就是希望q(x)尽可能地逼近甚至等于p(x) ,从而使得相对熵接近最小值0。由于真实的概率分布是固定的,相对熵公式的后半部分(-\int p(x)\ln p(x) dx) 就成了一个常数。相对熵的值大于等于0(https://zhuanlan.zhihu.com/p/28249050,这里给了证明),那么相对熵达到最小值的时候,也意味着交叉熵达到了最小值。对q(x) 的优化就等效于求交叉熵的最小值。另外,对交叉熵求最小值,也等效于求最大似然估计(maximum likelihood estimation)。

注意:交叉熵是衡量分布p与分布q的相似性,以前认为交叉熵的相似性越大,交叉熵的值就应该越大。但通过上面的推到可以看出,交叉熵得到两个分布的相似性是根据相对熵来的,所以相似性越大,交叉熵的值应该越小。

### 知识蒸馏中交叉熵KL区别及应用场景 在知识蒸馏的过程中,交叉熵KL是两种常用的损失函数。它们的使用方式适用场景有所不同,具体如下: #### 1. KL的定义与应用 KL(Kullback-Leibler Divergence)用于衡量两个概率分布之间的差异。在知识蒸馏中,KL被用来量化学生模型的输出分布与教师模型的输出分布之间的差异[^1]。通过最小化KL,学生模型能够学习到教师模型的概率分布,从而获得更优的性能。 公式表示为: \[ D_{KL}(P || Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)} \] 其中 \(P\) 表示教师模型的输出分布,\(Q\) 表示学生模型的输出分布。 在实际实现中,`torch.nn.KLDivLoss` 是一个常用工具,用于计算两个分布之间的KL[^2]。由于KL直接比较了两个分布的差异,因此它特别适合用于知识蒸馏中的软标签训练。 #### 2. 交叉熵的定义与应用 交叉熵(Cross-Entropy)是一种衡量两个概率分布之间差异的指标,常用于分类任务中。其公式为: \[ H(P, Q) = -\sum_{x} P(x) \log Q(x) \] 在知识蒸馏中,交叉熵通常用于硬标签训练,即目标分布是一个one-hot向量[^3]。在这种情况下,交叉熵可以简化为对数损失(Log Loss),并直接用于优化模型参数。 #### 3. 交叉熵KL区别 - **目标分布**:KL适用于软标签训练,目标分布是教师模型的输出概率分布;而交叉熵更适合硬标签训练,目标分布通常是one-hot编码。 - **计算复杂**:KL需要同时考虑目标分布模型分布,计算相对复杂;交叉熵则仅需关注目标分布下的对数概率值,计算更为简单。 - **应用场景**:KL广泛应用于知识蒸馏中的软标签训练,特别是在深度学习模型(如DeepSeek等大型语言模型)中;交叉熵则更多用于传统的分类任务或硬标签训练场景。 #### 4. 实现代码示例 以下是一个简单的PyTorch代码示例,展示如何使用KL交叉熵进行知识蒸馏训练: ```python import torch import torch.nn as nn import torch.nn.functional as F # 假设教师模型学生模型的输出 teacher_logits = torch.tensor([[3.0, 1.0, 0.2], [2.0, 5.0, 1.0]]) student_logits = torch.tensor([[2.0, 1.5, 0.3], [1.5, 4.0, 1.2]]) # 转换为概率分布 temperature = 2.0 teacher_probs = F.softmax(teacher_logits / temperature, dim=-1) student_probs = F.log_softmax(student_logits / temperature, dim=-1) # 计算KL损失 kld_loss = nn.KLDivLoss(reduction='batchmean')(student_probs, teacher_probs) * (temperature ** 2) print(f"KL Divergence Loss: {kld_loss.item()}") # 计算交叉熵损失(假设硬标签为[0, 1]) hard_labels = torch.tensor([0, 1]) cross_entropy_loss = nn.CrossEntropyLoss()(student_logits, hard_labels) print(f"Cross Entropy Loss: {cross_entropy_loss.item()}") ``` ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值