知识蒸馏的一些知识点

训练过程包括学生模型与教师模型的损失计算。clf_loss表示学生模型对真实标签的损失,kd_loss是基于教师模型软目标的损失。在计算时,先除以温度T再进行softmax操作,以减小loss的方差。教师模型参数保持不变,仅优化学生模型,权重分配给两种损失,使学生模型向教师模型学习。
摘要由CSDN通过智能技术生成
def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            logits_pre = self.student_model(x)
            t_logits_pre = self.teacher_model(x, training=False)

            clf_loss_value = self.clf_loss(y, tf.math.softmax(logits_pre))
            kd_loss_value = self.kd_loss(tf.math.softmax(t_logits_pre/self.T), tf.math.softmax(logits_pre/self.T))
            sum_loss_value = self.alpha * clf_loss_value + (1-self.alpha) * kd_loss_value

        self.optimizer.minimize(sum_loss_value, self.student_model.trainable_variables, tape=tape)
        self.compiled_metrics.update_state(y, tf.math.softmax(logits_pre))

        self.sum_loss_tracker.update_state(sum_loss_value)
        self.clf_loss_tracker.update_state(clf_loss_value)
        self.kd_loss_tracker.update_state(kd_loss_value)

        return {m.name: m.result() for m in self.metrics}

上述训练过程可以看成模型蒸馏其实就是计算学生和老师的loss,让学生向着老师loss的大方向和自己的小方向做更新。所以第一步当然是得到两个不一样的loss,只不过老师模型就不需要更新了。

loss来自hard target的loss和soft target的loss,权值往往需要softtarget的权值更大些

clf_loss:对应的是students模型计算的损失

kd_loss:对应的是teachers模型输出的损失,T为温度,作用是使模型的训练困难化

teachers模型的参数不发生变化

值得注意的点是:logits_pre 和 t_logits_pre 都是特征,没有经过softmax,所以并不是先softmax再/T,而是先/T在进行softmax,这对最终形成的loss的方差可以大大减小,因为softmax是x的指数函数与sum的指数函数的比值。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值