CVPR2019 | 关系型知识蒸馏法

CVPR 2019 | Relational Knowledge Distillation
https://github.com/HobbitLong/RepDistiller

1.蒸馏学习

由于大模型的拟合能力强,但计算效率低耗时大,而小模型的拟合能力弱,计算效率高。基于该特征,蒸馏学习的目的是让小模型学习大模型的拟合能力,在不改变计算效率的前提下提升小模型的拟合能力。如下图所示,传统的蒸馏学习(KD),直接根据小模型和大模型的输出值进行损失计算,使得小模型的输出能够靠近大模型的输出,以此来模型大模型的拟合能力。但这种方法很显然存在直观上的缺点,小模型只能学习大模型的输出表现,无法真正学习到大模型的结构信息。

传统的蒸馏学习的损失函数如下,其中ft表示教师模型的输出,fs表示学生模型的输出,L表示计算两者之间的距离。从损失函数中可以直观的看出,整个蒸馏学习过程中,小模型学习的就是大模型的输出表现,这种单点学习的方法是粗暴的,不具有结构性的。

2.关系型蒸馏学习

为了使得小模型能够更好的学习到大模型的结构信息,本文提出了关系型蒸馏学习法(RKD),如下图所示,RKD算法的核心是以多个教师模型的输出为结构单元,取代传统蒸馏学习中以单个教师模型输出为检测的方式,利用多输出组合成结构单元,更能体现出教师模型的结构化特征,使得学生模型得到更好的指导。

关系型蒸馏学习的损失函数如下,其中t1,t2…tn表示教师模型的多个输出,s1,s2…sn表示学生模型的多个输出,L表示计算两者之间的距离。与传统的蒸馏学习不同,关系型蒸馏学习的损失函数中还有一个构件结构信息的函数。可以使得学生模型学到教师模型中更加高效的信息表征能力。本文提出了两种表征结构信息的损失:距离蒸馏损失和角度蒸馏损失。

3.距离蒸馏损失(Distance-wise distillation loss)

基于距离的蒸馏损失的公式如下图所示,本文通过对每个batch中的样本进行两两距离计算,最终形成一个batch*batch大小的关系型结构输出。最终学生模型通过学习教师模型的结构输出,实现蒸馏学习。整体的代码如下所示。

 # RKD distance loss
with torch.no_grad():
    t_d = self.pdist(teacher, squared=False)
    mean_td = t_d[t_d > 0].mean()
    t_d = t_d / mean_td

d = self.pdist(student, squared=False)
mean_d = d[d > 0].mean()
d = d / mean_d
print("d:{},t_d:{}".format(d.size(),t_d.size()))
loss_d = F.smooth_l1_loss(d, t_d)

def pdist(e, squared=False, eps=1e-12):
	e_square = e.pow(2).sum(dim=1)
	prod = e @ e.t()
	res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

	if not squared:
		res = res.sqrt()

	res = res.clone()
	res[range(len(e)), range(len(e))] = 0

	# print("e_square:{}".format(e_square.size()))
	# print("e.t:{},prod:{}".format(e.t().size(),prod.size()))
	# print("unsqueeze(1):{},unsqueeze(0):{}".format(e_square.unsqueeze(1).size(),e_square.unsqueeze(0).size()))
	# print("res:{},len(e):{}".format(res.size(),len(e)))

	return res

4.角度蒸馏损失(Angle-wise distillation loss)

基于角度的蒸馏损失的公式如下图所示,本文通过对每个batch中的样本三三样本,计算两个角度,最终形成一个batchbatchbatch大小的关系型结构输出。最终学生模型通过学习教师模型的结构输出,实现蒸馏学习。整体的代码如下所示。

# RKD Angle loss
with torch.no_grad():
	td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
	norm_td = F.normalize(td, p=2, dim=2)
	t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
	print("unsqueeze(0):{},unsqueeze(1):{}".format(teacher.unsqueeze(0).size(),teacher.unsqueeze(1).size()))
	print("td:{},norm_td:{},norm_td.transpose(1, 2):{},t_angle:{}".format(td.size(),norm_td.size(),norm_td.transpose(1, 2).size(),t_angle.size()))

sd = (student.unsqueeze(0) - student.unsqueeze(1))
norm_sd = F.normalize(sd, p=2, dim=2)
s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
loss_a = F.smooth_l1_loss(s_angle, t_angle)

5.关系型蒸馏效果

本文提出的关系型蒸馏学习方案在各个公开数据集上都证明了有效性,相较于传统的蒸馏学习方案,本文通过结构化输出的监督,获取了更好的监督学习结果。

RKD_LOSS整体代码请关注公众号【CV炼丹猿】,后台回复RKD获取。

  • 6
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
CVPR 2019中发表了一篇题为“迁移学习:无监督领域自适应的对比适应网络(Contrastive Adaptation Network for Unsupervised Domain Adaptation)”的论文。这篇论文主要介绍了一种用于无监督领域自适应的对比适应网络。 迁移学习是指将从一个源领域学到的知识应用到一个目标领域的任务中。在无监督领域自适应中,源领域和目标领域的标签信息是不可用的,因此算需要通过从源领域到目标领域的无监督样本对齐来实现知识迁移。 该论文提出的对比适应网络(Contrastive Adaptation Network,CAN)的目标是通过优化源领域上的特征表示,使其能够适应目标领域的特征分布。CAN的关键思想是通过对比损失来对源领域和目标领域的特征进行匹配。 具体地说,CAN首先通过一个共享的特征提取器来提取源领域和目标领域的特征表示。然后,通过对比损失函数来测量源领域和目标领域的特征之间的差异。对比损失函数的目标是使源领域和目标领域的特征在特定的度量空间中更加接近。最后,CAN通过最小化对比损失来优化特征提取器,以使源领域的特征能够适应目标领域。 该论文还对CAN进行了实验验证。实验结果表明,与其他无监督领域自适应方相比,CAN在多个图像分类任务上取得了更好的性能,证明了其有效性和优越性。 综上所述,这篇CVPR 2019论文介绍了一种用于无监督领域自适应的对比适应网络,通过对源领域和目标领域的特征进行对比学习,使得源领域的特征能够适应目标领域。该方在实验中展现了较好的性能,有望在无监督领域自适应任务中发挥重要作用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yuanCruise

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值