pytorch中的KL散度详解torch.nn.functional.kl_div

KL散度是衡量两个概率分布相似度的指标,在PyTorch中通过F.kl_div函数实现。该函数接受对数概率和目标概率分布作为输入,支持sum,mean,none三种_reduction_模式来处理输出。默认行为是计算平均KL散度。示例展示了如何使用softmax转换logits并计算批次的平均KL散度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

KL散度的公式

F.kl_div 是 PyTorch 中的一个函数,用于计算两个概率分布之间的 Kullback-Leibler (KL) 散度。

KL 散度是一种非对称的测量,用于衡量两个概率分布的相似度。如果两个分布完全相同,KL 散度为零;否则,KL 散度为一个正数。

在 PyTorch 中,F.kl_div 的输入是两个张量,其中第一个张量的每个元素应该是第二个张量对应元素的对数概率。因此,F.kl_div 的输入应该满足下面的条件:

  1. 第一个输入张量 input:这个张量的元素应该是第二个张量对应元素的对数概率,即 log P(x)

  2. 第二个输入张量 target:这个张量的元素是目标概率分布,即 Q(x)

F.kl_div 的计算公式是:

KL(P || Q) = Σ P(x) * log(P(x) / Q(x))

在 PyTorch 中,这个公式稍有改动,改为:

KL(P || Q) = Σ P(x) * (log(P(x)) - Q(x))

这是因为第一个输入张量已经是对数概率了。

torch.nn.functional.kl_div中的参数

PyTorch 的 KL 散度函数 F.kl_div 的行为取决于两个参数:size_averagereductionreduction 参数有三个可能的值:'none', 'sum', 和 'mean',而 size_average 只有在 reduction'mean' 时才有意义。

  • 如果 reduction='none',那么 F.kl_div 将返回一个和输入张量同样大小的新张量,其中每个元素表示相应位置的 KL 散度。

  • 如果 reduction='sum',那么 F.kl_div 将返回一个标量,表示所有元素的 KL 散度的总和。

  • 如果 reduction='mean',那么 F.kl_div 的行为将取决于 size_average 参数:

    • 如果 size_average=True,那么 F.kl_div 将返回一个标量,表示所有元素的 KL 散度的平均值。
    • 如果 size_average=False,那么 F.kl_div 将返回一个标量,表示所有元素的 KL 散度的总和。

这个函数的默认行为是 reduction='mean'size_average=True,也就是返回所有元素的 KL 散度的平均值。

如果你想要得到所有样本 KL 散度的平均值,你需要设置 reduction='mean'size_average=True

使用示例

假设我们有一个 2D 张量 logits,其维度是 [batch_size, num_classes],每一行代表一个样本,每一列代表一个类别。这样,logits[i, j] 就代表了第 i 个样本属于第 j 类的对数概率。我们通常将这个张量通过 softmax 函数转换成概率分布,即每一行的所有元素都是非负的,并且和为1。

例如,假设我们有如下的 logits:

logits = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])

这意味着我们有两个样本(批次大小为2),每个样本有三个类别的概率分布。我们可以使用 softmax 函数将 logits 转换成概率分布:

probs = torch.nn.functional.softmax(logits, dim=1)
print(probs)

输出:

tensor([[0.2271, 0.4116, 0.3613],
        [0.2900, 0.3200, 0.3900]])

这里,probs[i, j] 代表第 i 个样本属于第 j 类的概率。你可以看到每一行的所有元素都是非负的,并且和为1,因此每一行都是一个概率分布。

然后,我们可以用 F.kl_div 计算这个概率分布和一个目标概率分布之间的 KL 散度。假设目标概率分布是:

target_probs = torch.tensor([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]])

那么,我们可以计算 KL 散度如下:

loss = F.kl_div(torch.log(probs), target_probs, reduction='batchmean')
print(loss)

这里 torch.log(probs) 计算的是每个元素的对数概率,target_probs 是目标概率分布,reduction='batchmean' 表示计算批次的平均 KL 散度。

### YOLOv8 知识蒸馏代码实现与解释 #### 1. 环境配置 为了确保能够顺利运行YOLOv8的知识蒸馏代码,环境配置至关重要。建议使用Python虚拟环境来管理依赖项,并安装必要的库和工具[^1]。 ```bash conda create -n yolov8_distillation python=3.9 conda activate yolov8_distillation pip install ultralytics torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 ``` #### 2. 数据准备 数据集对于训练模型非常重要,在进行知识蒸馏之前需准备好相应的图像分类或目标检测数据集并按照指定格式整理好文件结构。 #### 3. Logits-Based 蒸馏方法 Logits-Based 方法是最简单的知识蒸馏形式之一,通过让小型学生网络模仿大型教师网络的输出分布来进行学习。具体来说就是最小化两者之间的差异损失函数: \[ L_{distill} = \frac{1}{N}\sum_i^N KL(\text{softmax}(T\cdot s(x_i)) || \text{softmax}(T\cdot t(x_i))) \] 其中\(s(x)\)表示学生模型预测值;\(t(x)\)代表老师模型预测结果;\(KL\)指代Kullback-Leibler;而参数\(T>0\)则用来调整温以控制软概率分布的程。 ```python import torch.nn.functional as F def logits_based_loss(student_logits, teacher_logits, temperature=4): """计算基于logits的知识蒸馏损失""" soft_student = F.log_softmax(student_logits / temperature, dim=-1) soft_teacher = F.softmax(teacher_logits / temperature, dim=-1) return F.kl_div( soft_student, soft_teacher, reduction="batchmean" ) * (temperature ** 2) ``` #### 4. Feature-Based 蒸馏方法 Feature-Based 方式则是提取中间层特征图作为监督信号传递给学生网路,从而使得其内部表征更加接近于教师模型。通常采用均方误差(MSE)或其他相似性测衡量两者的差距: \[ L_{feat\_distill}=\left \| f_s(X)-f_t(X) \right \|_F^{2} \] 这里\(f_s()\) 和 \(f_t()\),分别对应着学生和老师的某一层激活响应矩阵; 符号\(||\cdot||_F\) 表明 Frobenius范数运算操作。 ```python from functools import partial class FeatureDistiller(nn.Module): def __init__(self, student_model, teacher_model, layers=('layer2', 'layer3')): super().__init__() self.student_features = [] self.teacher_features = [] # 注册钩子获取特定层的特征图 for layer_name in layers: getattr(student_model.model[layer_name], "register_forward_hook")(partial(self._hook_fn, is_student=True)) getattr(teacher_model.model[layer_name], "register_forward_hook")(partial(self._hook_fn, is_student=False)) def _hook_fn(self, module, input, output, is_student): if is_student: self.student_features.append(output.detach()) else: self.teacher_features.append(output.detach()) def forward(self, inputs): outputs = {} with torch.no_grad(): _ = self.teacher(inputs) _ = self.student(inputs) feature_losses = [ F.mse_loss(s_feat, t_feat) for s_feat, t_feat in zip(self.student_features, self.teacher_features) ] total_feature_loss = sum(feature_losses)/len(feature_losses) outputs['total_feature_loss'] = total_feature_loss return outputs ```
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值