Relational Knowledge Distillation------论文阅读笔记(CVPR2019)

RKD (Abstract+introduction)

在这里插入图片描述
在这里插入图片描述
关系知识提取(RKD),它传递输出的结构关系,而不是单个输出本身(图1)。
在这里插入图片描述
两种RKD损失:距离方向(二阶)和角度方向(三阶)

2. Related Work

比较感兴趣的:
在这里插入图片描述
训练浅层神经网络来模拟深层神经网络,并惩罚两个网络之间的逻辑差异,从而提高了浅层神经网络的精度。
在这里插入图片描述
LXu等人[43]提出了一种条件对抗网络来学习KD的损失函数。
Crowley等人[8]通过对模型的卷积通道进行分组并使用注意力转移对其进行训练来压缩模型。

Moonshine:Distilling with Cheap Convolutions

在这里插入图片描述

RKD的通用框架,并证明了其适用于各种任务。

3. Our Approach

在这里插入图片描述
FT和FS分别是教师和学生的函数,
函数f可以使用网络任何层(例如,隐藏层或softmax层)的输出来定义。
我们用不同数据示例的N元组的XN集表示, X2, X3

3.1 IKD(Individual KD)

这里针对的都是不同层的输出作KD。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
β 是为了弥补学生网络维度较小的一个线性映射(简单来说线性扩维)

3.2 RKD(Relational knowledge distillation)

在这里插入图片描述

φ 代表一种关系趋势函数
x1----xn是输入, XN则代表N个输入样本的
ti和si表示 老师和学生对应输入的输出

RKD训练学生和教师模型使用的相同关系势函数。由于这种潜力,它能够传递高阶属性的知识,即使教师和学生之间的输出维度不同,高阶属性对低阶属性也是不变的。
在这里插入图片描述

3.2.1 Distance-wise distillation loss (距离蒸馏损失)

**ψD**测量在输出表示空间中两个示例之间的欧氏距离
µ设置为小批量中X2对之间的平均距离
在这里插入图片描述

尤其是当 教师距离 和 学生距离 由于输出尺寸的差异 存在显著差异时,小批量距离标准化非常有用。
在这里插入图片描述
RKD,它并不不是强迫学生直接匹配教师的输出,而是鼓励学生关注输出的距离结构。

3.2.2 Angle-wise distillation loss

在这里插入图片描述

我们观察到角度损失通常允许更快的收敛和更好的性能。

3.2.3 Training with RKD

在这里插入图片描述

超参数 用来平衡loss的

In sampling tuples of examples for the proposed distillation losses, we simply use all possible tuples (i.e., pairs or triplets) from examples in a given mini-batch.
在mini-batch中抽取pairs 或者 tripets.

3.2.4 Distillation target layer

4. Experiments

在这里插入图片描述
我们在三个不同的任务上评估RKD:度量学习、分类和少量镜头学习。
为了公平比较,我们使用网格搜索调整竞争方法的超参数。

4.1. Metric learning

在这里插入图片描述

4.2. Image classification

在这里插入图片描述

4.3 Few-shot learning

在这里插入图片描述

代码实现

这里我有个idea,注释去掉了,有问题可以看下草稿纸截图帮助理解,有问题私聊评论区都ok

from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F


class RKDLoss(nn.Module):
    """Relational Knowledge Disitllation, CVPR2019"""
    def __init__(self, w_d=25, w_a=50):
        super(RKDLoss, self).__init__()
        self.w_d = w_d
        self.w_a = w_a

    def forward(self, f_s, f_t):
        student = f_s.view(f_s.shape[0], -1)
        teacher = f_t.view(f_t.shape[0], -1)

        # RKD distance loss
        with torch.no_grad():
            t_d = self.pdist(teacher, squared=False)
            mean_td = t_d[t_d > 0].mean()
            t_d = t_d / mean_td

        d = self.pdist(student, squared=False)
        mean_d = d[d > 0].mean()
        d = d / mean_d

        loss_d = F.smooth_l1_loss(d, t_d)

        # RKD Angle loss
        with torch.no_grad():
            td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
            norm_td = F.normalize(td, p=2, dim=2)
            t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)

        sd = (student.unsqueeze(0) - student.unsqueeze(1))
        norm_sd = F.normalize(sd, p=2, dim=2)
        s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)

        loss_a = F.smooth_l1_loss(s_angle, t_angle)

        loss = self.w_d * loss_d + self.w_a * loss_a

        return loss

    @staticmethod
    def pdist(e, squared=False, eps=1e-12):
        e_square = e.pow(2).sum(dim=1)
        prod = e @ e.t()
        res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

        if not squared:
            res = res.sqrt()

        res = res.clone()
        res[range(len(e)), range(len(e))] = 0
        return res

在这里插入图片描述

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值