class DistillKL(nn.Module):
    """KL-divergence loss for logit (response-based) knowledge distillation."""

    def __init__(self, temperature):
        super(DistillKL, self).__init__()
        self.T = temperature

    def forward(self, y_s, y_t):
        """Return the temperature-scaled KL divergence between student and teacher logits.

        Args:
            y_s: student logits, shape (batch, classes).
            y_t: teacher logits, same shape; gradients are blocked via detach.
        """
        # Soften both distributions with the temperature before comparing.
        log_p_student = F.log_softmax(y_s / self.T, dim=1)
        p_teacher = F.softmax(y_t / self.T, dim=1).detach()
        batch_size = y_s.shape[0]
        # T^2 restores gradient magnitude after temperature scaling (Hinton et al.);
        # the summed divergence divided by batch size is a per-sample average.
        divergence = F.kl_div(log_p_student, p_teacher, reduction='sum')
        return divergence * (self.T ** 2) / batch_size
class KL(nn.Module):
    """Combined distillation loss: logit KD plus an attention-transfer feature term.

    loss = alpha * DistillKL(student_logits, teacher_logits)
         + beta  * sum of attention-map MSEs over paired feature layers

    Args:
        temperature: softmax temperature forwarded to the logit-KD term.
        alpha: weight of the logit-distillation term.
        beta: weight of the feature (attention-transfer) term.
        p: exponent used when collapsing channels into an attention map.
            Defaults to 2, preserving the previous hard-coded behavior.
    """

    def __init__(self, temperature, alpha, beta, p=2):
        super(KL, self).__init__()
        self.p = p
        self.kd = DistillKL(temperature=temperature)
        self.alpha = alpha
        self.beta = beta

    def forward(self, o_s, o_t, g_s, g_t):
        """Return the weighted sum of logit and feature distillation losses.

        Args:
            o_s, o_t: student / teacher logits, shape (batch, classes).
            g_s, g_t: parallel sequences of student / teacher feature maps;
                each presumably (batch, channels, H, W) — TODO confirm with caller.
        """
        loss = self.alpha * self.kd(o_s, o_t)
        # Teacher features are detached so no gradient flows into the teacher.
        loss += self.beta * sum(self.at_loss(f_s, f_t.detach())
                                for f_s, f_t in zip(g_s, g_t))
        return loss

    def at_loss(self, f_s, f_t):
        # Mean squared error between normalized attention maps.
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()

    def at(self, f):
        # p-th-power channel mean -> flatten per sample -> L2 normalize.
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
# The first class is a loss that distills logits only; the second combines
# logit distillation with feature (attention-map) distillation.