- 蒸馏loss
KLDiv-最原始表达方式
import torch.nn as nn
import torch.nn.functional as F
# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单纯的kd_loss
class KLDiv(nn.Module):
def __init__(self, temp=1.0):
super(KLDiv, self).__init__()
self.temp = temp
def forward(self, student_preds, teacher_preds, **kwargs):
soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)
soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)
kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()
kd_loss *= self.temp ** 2
return kd_loss
学生模型:对数概率分布,F.log_softmax()
教师模型: softmax,F.softmax()
温度系数:temp。为了提供更多信息,引入了“softmax温度”的概念,通过调整温度参数T,可以影响softmax函数生成的概率分布。当T=1时,得到标准的softmax函数,而当T增大时,softmax函数生成的概率分布变得更加柔和,提供了更多关于模型认为哪些类别与预测类别更相似的信息。这种调整温度的方法可以帮助传递大模型中所含的“暗知识”到小模型中。
- nn.KLDivLoss——KL散度损失
作用:衡量连续分布的距离,常用于label smoothing,并且对离散采用的连续输出空间分布进行回归通常很有用
import torch
import torch.nn as nn
import math
# # 手动实现
def validate_loss(output, target):
val = 0
for li_x, li_y in zip(output, target):
for i, xy in enumerate(zip(li_x, li_y)):
x, y = xy
loss_val = y * (math.log(y, math.e) - x)
val += loss_val
return val / output.nelement()
torch.manual_seed(20)
loss = nn.KLDivLoss()
input = torch.Tensor([[-2, -6, -8], [-7, -1, -2], [-1, -9, -2.3], [-1.9, -2.8, -5.4]])
target = torch.Tensor([[0.8, 0.1, 0.1], [0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.4, 0.3, 0.3]])
output = loss(input, target)
print("default loss:", output)
output = validate_loss(input, target)
print("validate loss:", output)