蒸馏论文七(Variational Information Distillation)

本文介绍一种知识蒸馏的方法(Variational Information Distillation)。

1. 核心思想

作者定义了将互信息定义为:

在这里插入图片描述
如上式所述,互信息为 = 教师模型的熵值 - 已知学生模型的条件下的教师模型熵值。

我们有如下常识:当学生模型已知,能够使得教师模型的熵很小,这说明学生模型以及获得了能够恢复教师模型所需要的“压缩”知识,间接说明了此时学生模型已经学习的很好了,也就是说明上式中的H(t|s)很小,从而使得互信息I(t;s)会很大。因此,就可以通过最大化互信息的方式来进行蒸馏学习。
在这里插入图片描述
如图所示,学生网络与教师网络保持高互信息(MI),通过学习并估计教师网络中的分布,激发知识的传递,使相互信息最大化。

2. 损失函数

由于p(t|s)难以计算,作者根据IM算法,利用一个可变高斯q(t|s)来模拟p(t|s)

在这里插入图片描述
上述公式中的大于等于操作用到了KL散度的非负性。由于蒸馏过程中H(t)和需要学习的学生模型参数无关,因此最大化互信息就转换为最大化可变高斯分布的问题。

作者利用一个均值,方差可学习的高斯分布来模拟上述的q(t|s)
在这里插入图片描述
式子中可学习的方差定义如下:
在这里插入图片描述
其中阿尔法c是可学习参数。

class VIDLoss(nn.Module):
    """Variational Information Distillation for Knowledge Transfer (CVPR 2019),
    code from author: https://github.com/ssahn0215/variational-information-distillation"""
    def __init__(self,
                 num_input_channels,
                 num_mid_channel,
                 num_target_channels,
                 init_pred_var=5.0,
                 eps=1e-5):
        super(VIDLoss, self).__init__()

        def conv1x1(in_channels, out_channels, stride=1):
            return nn.Conv2d(in_channels, out_channels,kernel_size=1, padding=0,bias=False, stride=stride)
        
        # 通过一个卷积网络来模拟可变均值
        self.regressor = nn.Sequential(
            conv1x1(num_input_channels, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_target_channels),
        )

        # 可学习参数
        self.log_scale = torch.nn.Parameter(
            np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
            )

        self.eps = eps

    def forward(self, input, target):
        # pool for dimentsion match
        s_H, t_H = input.shape[2], target.shape[2]
        if s_H > t_H:
            input = F.adaptive_avg_pool2d(input, (t_H, t_H))
        elif s_H < t_H:
            target = F.adaptive_avg_pool2d(target, (s_H, s_H))
        else:
            pass

        # 均值方差
        pred_mean = self.regressor(input)
        pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
        pred_var = pred_var.view(1, -1, 1, 1)

        # 利用均值和方差可学习的高斯分布来模拟概率
        neg_log_prob = 0.5*(
            (pred_mean-target)**2/pred_var + torch.log(pred_var)
            )

        loss = torch.mean(neg_log_prob)
        return loss

3. 训练

# 损失函数
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)

s_n = [f.shape[1] for f in feat_s[1:-1]]
t_n = [f.shape[1] for f in feat_t[1:-1]]
criterion_kd = nn.ModuleList(
    [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)]
)

for idx, data in enumerate(train_loader):
	# ===================forward=====================
	loss_cls = criterion_cls(logit_s, target)
	loss_div = criterion_div(logit_s, logit_t)
	        
    g_s = feat_s[1:-1]
    g_t = feat_t[1:-1]
    loss_group = [c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, criterion_kd)]
    loss_kd = sum(loss_group)
	
	loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
    # ===================backward=====================
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # ===================meters=====================
    batch_time.update(time.time() - end)
    end = time.time()

其中的feat_s是中间特征层,例如对于resnet8

if preact:
    return [f0, f1_pre, f2_pre, f3_pre, f4], x
else:
    return [f0, f1, f2, f3, f4], x

源代码
参考文献:CVPR 2019 | VID_最大化互信息知识蒸馏

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值