交叉熵函数与kl散度的区别

公式上的区别

手动计算的方式展示如何实现这两个损失函数

交叉熵损失函数

import torch
import torch.nn.functional as F

# 模型的输出 logits 和真实标签 target
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
target = torch.tensor([0, 1])  # 真实标签

# 计算 softmax 以获得预测概率
pred_probs = F.softmax(logits, dim=1)

# 将 target 转换为 one-hot 编码
target_one_hot = F.one_hot(target, num_classes=logits.size(1))

# 交叉熵损失公式:L = - Σ y * log(ŷ)
cross_entropy_loss = - torch.sum(target_one_hot * torch.log(pred_probs)) / logits.size(0)

print('手动实现的交叉熵损失:', cross_entropy_loss)

kl散度

import torch
import torch.nn.functional as F

# 模型的输出 logits 和目标分布 target_probs
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
target_probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.7, 0.2]])  # 目标分布

# 将 logits 转换为 log softmax
logits_log_softmax = F.log_softmax(logits, dim=1)

# KL 散度公式:D_KL(P || Q) = Σ P * (log P - log Q)
kl_div_loss = torch.sum(target_probs * (torch.log(target_probs) - logits_log_softmax)) / logits.size(0)

print('手动实现的KL散度:', kl_div_loss)

官方打包好的函数

交叉熵损失 (Cross Entropy Loss) 官方实现

import torch
import torch.nn as nn

# 创建交叉熵损失函数实例
cross_entropy_loss_fn = nn.CrossEntropyLoss()

# 假设模型的输出 logits 和真实标签 targets
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
targets = torch.tensor([0, 1])  # 真实标签 (整数形式)

# 计算交叉熵损失
loss = cross_entropy_loss_fn(logits, targets)
print('官方交叉熵损失:', loss)

KL 散度损失 (KL Divergence Loss) 官方实现

import torch
import torch.nn as nn
import torch.nn.functional as F

# 创建KL散度损失函数实例
kl_div_loss_fn = nn.KLDivLoss(reduction='batchmean')  # 使用 'batchmean' 计算每个样本的平均损失

# 假设模型的输出 logits 和目标分布 target_probs
logits = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
target_probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.7, 0.2]])  # 目标分布 (已经是 softmax 概率)

# 计算 log softmax 以用于 KL 散度
logits_log_softmax = F.log_softmax(logits, dim=1)

# 计算KL散度损失
kl_loss = kl_div_loss_fn(logits_log_softmax, target_probs)
print('官方KL散度损失:', kl_loss)

加上温度

是的,KL 散度在深度学习中,尤其是知识蒸馏(Knowledge Distillation)中,常常与温度参数(Temperature, TTT)结合起来使用。温度调节可以让模型的预测分布更加平滑,从而在蒸馏过程中更有效地传递知识。下面将解释为什么 KL 散度与温度结合,以及如何使用温度参数。

温度在知识蒸馏中的作用

在知识蒸馏中,通常有一个教师模型(Teacher Model)和一个学生模型(Student Model)。教师模型的输出概率分布用来指导学生模型的训练,但直接使用教师模型的概率分布往往过于“尖锐”(即,教师模型的 softmax 输出大部分概率集中在正确类别)。为了使分布更加平滑,加入了温度参数。

温度对 softmax 的影响

softmax 函数将模型的 logits 转换为概率分布。加上温度参数 TTT 后的 softmax 表达式为:

  • 当 T=1T = 1T=1,softmax 正常工作,输出标准的概率分布。
  • 当 T>1T > 1T>1,softmax 输出变得更加平滑(分布更加均匀)。
  • 当 T<1T < 1T<1,softmax 输出更加“尖锐”,即概率分布更加接近 one-hot 编码。

在知识蒸馏中,通过引入较高的温度 TTT,可以让教师模型输出的概率分布变得更加平滑,从而包含更多类的信息,帮助学生模型更好地学习。

温度结合 KL 散度

在知识蒸馏过程中,学生模型通常通过最小化学生模型与教师模型之间的KL 散度来学习教师模型的输出分布。引入温度后,KL 散度的损失计算公式如下:

其中:

  • T2是为了平衡梯度的影响,避免由于高温度导致的梯度缩小。
  • DKL​ 表示 KL 散度,用于比较教师模型和学生模型的概率分布。

PyTorch 实现带温度的 KL 散度

你可以在 PyTorch 中手动实现带温度参数的 KL 散度,如下所示:

import torch
import torch.nn.functional as F

def distillation_kl_divergence_loss(logits_student, logits_teacher, temperature):
    """
    计算带温度参数的KL散度,用于知识蒸馏
    :param logits_student: 学生模型的logits (未经过softmax)
    :param logits_teacher: 教师模型的logits (未经过softmax)
    :param temperature: 温度参数T
    :return: 知识蒸馏中的KL散度损失
    """
    # 计算log softmax(学生和教师模型的输出都经过温度缩放)
    log_probs_student = F.log_softmax(logits_student / temperature, dim=1)
    probs_teacher = F.softmax(logits_teacher / temperature, dim=1)

    # 计算KL散度损失,并乘以T^2
    kl_div_loss = F.kl_div(log_probs_student, probs_teacher, reduction='batchmean') * (temperature ** 2)
    
    return kl_div_loss

# 示例
logits_student = torch.tensor([[2.0, 0.5, 0.1], [0.3, 2.5, 0.8]], requires_grad=True)
logits_teacher = torch.tensor([[2.1, 0.6, 0.1], [0.2, 2.6, 0.7]], requires_grad=False)
temperature = 2.0  # 温度参数

# 计算带温度的KL散度损失
kl_loss = distillation_kl_divergence_loss(logits_student, logits_teacher, temperature)
print('带温度的KL散度损失:', kl_loss)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值