Similarity-Preserving Knowledge Distillation
Abstract
在训练网络的过程中,语义相似的输入倾向于在引发相似的激活模式。保持相似性的知识提取指导学生网络的训练,使在教师网络中产生相似(不同)激活的输入指导学生网络中产生相似(不同)激活。与以前的提取方法不同,学生不需要模仿教师的表示空间,而是要在自己的表示空间中保留成对的相似性。
1. Introduction
图1.保持相似性的知识提取指导学生网络的训练,使得在预先训练的教师网络中产生相似(不同)激活的输入对在学生网络中产生相似(不同)激活。给定一小批b张输入的图像,我们从激活图中导出b×b成对相似矩阵,并计算学生和教师产生的矩阵上的蒸馏损失。
这篇文章的想法和RKD的想法及其相似,不知道谁参考谁的,让人难免有点copy的感觉。
贡献:
- 我们引入了保持相似性的知识提取,这是一种新的知识提取形式,它使用每个输入小批量中的成对激活相似性来监督学生网络和经过训练的教师网络的训练。(实例间的相似矩阵)
- 我们在三个公共数据集上实验验证了我们的方法。我们的实验表明,保持相似性的知识提取不仅可以提高学生网络的训练效果,而且可以补充传统的知识提取方法。
2. Method
注释 :
- AT AS 来说是对应的特征图 ,
- 通道 和 h w 都可以不一致,
- l 和 l’ 可以是corresponding layer 对应层或者说相同深度的层。
这里和RKD已经非常相似。
做法也比较简单,求出实例间的关系矩阵,且归一化。 GT GS
where γ is a balancing hyperparameter.
3. Experiments
3.1. CIFAR-10
3.2. Transfer learning combining distillation with fine-tuning
3.3. CINIC-10
代码
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
class Similarity(nn.Module):
"""Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
def __init__(self):
super(Similarity, self).__init__()
def forward(self, g_s, g_t):
return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
def similarity_loss(self, f_s, f_t):
bsz = f_s.shape[0]
f_s = f_s.view(bsz, -1)
f_t = f_t.view(bsz, -1)
G_s = torch.mm(f_s, torch.t(f_s))
# G_s = G_s / G_s.norm(2)
G_s = torch.nn.functional.normalize(G_s)
G_t = torch.mm(f_t, torch.t(f_t))
# G_t = G_t / G_t.norm(2)
G_t = torch.nn.functional.normalize(G_t)
G_diff = G_t - G_s
loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
return loss