知识蒸馏基本知识点

  1. 蒸馏loss
    KLDiv-最原始表达方式
import torch.nn as nn
import torch.nn.functional as F

# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单纯的kd_loss
class KLDiv(nn.Module):
    def __init__(self, temp=1.0):
        super(KLDiv, self).__init__()
        self.temp = temp

    def forward(self, student_preds, teacher_preds, **kwargs):
        soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)
        soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)
        kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()
        kd_loss *= self.temp ** 2
        return kd_loss

学生模型:对数概率分布,F.log_softmax()
教师模型: softmax,F.softmax()
温度系数:temp。为了提供更多信息,引入了“softmax温度”的概念,通过调整温度参数T,可以影响softmax函数生成的概率分布。当T=1时,得到标准的softmax函数,而当T增大时,softmax函数生成的概率分布变得更加柔和,提供了更多关于模型认为哪些类别与预测类别更相似的信息。这种调整温度的方法可以帮助传递大模型中所含的“暗知识”到小模型中。

  1. nn.KLDivLoss——KL散度损失
    作用:衡量连续分布的距离,常用于label smoothing,并且对离散采用的连续输出空间分布进行回归通常很有用
    在这里插入图片描述
import torch
import torch.nn as nn
import math

# # 手动实现
def validate_loss(output, target):
    val = 0
    for li_x, li_y in zip(output, target):
        for i, xy in enumerate(zip(li_x, li_y)):
            x, y = xy
            loss_val = y * (math.log(y, math.e) - x)
            val += loss_val
    return val / output.nelement()

torch.manual_seed(20)
loss = nn.KLDivLoss()
input = torch.Tensor([[-2, -6, -8], [-7, -1, -2], [-1, -9, -2.3], [-1.9, -2.8, -5.4]])
target = torch.Tensor([[0.8, 0.1, 0.1], [0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.4, 0.3, 0.3]])
output = loss(input, target)
print("default loss:", output)

output = validate_loss(input, target)
print("validate loss:", output)
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值