【知识蒸馏|模型压缩】什么是教师网络Teacher Network?什么是学生网络Student Network?知识蒸馏(Knowledge Distillation,KD)的实质是“模型压缩”!

【知识蒸馏|模型压缩】什么是教师网络Teacher Network?什么是学生网络Student Network?知识蒸馏(Knowledge Distillation,KD)的实质是“模型压缩”!

【知识蒸馏|模型压缩】什么是教师网络Teacher Network?什么是学生网络Student Network?知识蒸馏(Knowledge Distillation,KD)的实质是“模型压缩”!



1.知识蒸馏的基础理论

1.1背景与概念

知识蒸馏(Knowledge Distillation, KD)最早由Geoffrey Hinton等人在2015年提出,主要用于模型压缩。其核心思想是通过训练一个轻量级的学生网络(Student Network),使其学习一个复杂、性能较好的教师网络(Teacher Network)中的知识。这种方法通过让学生网络模仿教师网络的预测输出,使得在不牺牲性能的前提下大幅度降低模型的复杂性。

知识蒸馏的动机在于,复杂的大型深度网络在许多应用中(如手机、物联网设备)计算成本过高,运行速度慢,且占用大量内存,而通过知识蒸馏,可以保持性能的同时,使用更加轻量化的模型

1.2知识蒸馏的原理

在知识蒸馏过程中,训练学生网络不仅要学习正确类别的标签,还要学习教师网络提供的软标签(soft labels)。软标签包含教师网络在各类之间的输出概率分布信息,这种分布比**原始硬标签(如one-hot编码)**提供了更多的知识,有助于学生网络更好地学习。

  • 软标签是通过使用温度参数 T T T调整softmax层得到的:
    在这里插入图片描述
    其中, z i z_i zi是模型在类别 i i i上的原始输出logit, T T T为温度参数。当 T > 1 T>1 T>1时,软化的概率分布能提供更多关于类间相似性的知识。

1.3知识蒸馏的损失函数

知识蒸馏的总损失通常由两部分组成:

  • (1)交叉熵损失(用于真实标签的学习):
    在这里插入图片描述
    其中, y i y_i yi是真实标签, p i p_i pi 是学生网络的输出。
  • (2)蒸馏损失(用于学习教师网络的知识):
    在这里插入图片描述
    其中, q i T q^T_i qiT p i T p^T_i piT分别是教师网络和学生网络的软标签。

总损失为二者的加权和:
在这里插入图片描述
其中, α α α为权重系数,用于平衡两个损失的贡献。

2.知识蒸馏的扩展与改进

自从Hinton提出知识蒸馏以来,研究者们提出了许多改进和扩展方法,以提升蒸馏的效果,适应不同任务和场景。

2.1基于特征的蒸馏

除了蒸馏教师网络的输出结果外,特征蒸馏通过让学生网络学习教师网络中间层的特征表示,进一步提升了学生网络的学习能力。这种方式特别适合处理学生网络和教师网络架构不同的情况。

2.2自蒸馏(Self-Distillation)

在自蒸馏中,教师网络和学生网络是同一个模型,或是将一个模型的不同阶段的输出作为教师网络的知识。这种方法不需要额外的教师网络,降低了训练成本。

2.3多教师蒸馏

多教师蒸馏从多个教师网络中提取知识,让学生网络从不同视角获取信息,从而提升模型的泛化能力

2.4蒸馏在强化学习中的应用

知识蒸馏不仅适用于监督学习,还可以扩展到强化学习中。在这种场景下,学生网络通过模仿教师网络的策略学习,提升其在复杂环境中的决策能力。

3.知识蒸馏的实际用途

知识蒸馏广泛应用于以下场景

  • 模型压缩:将大型网络压缩成轻量化的网络,常见于移动设备上的应用,如图像分类、目标检测。
  • 加速推理:通过知识蒸馏,学生网络比教师网络计算量少,从而加速推理速度。
  • 模型集成:将多个模型的输出蒸馏到单个学生模型上,减少模型集成的复杂性。
  • 迁移学习:利用已经训练好的教师网络,将知识迁移到更小的学生网络中,适应不同的任务和领域。
  • 跨领域学习:将不同任务中的知识通过蒸馏进行迁移,如自然语言处理模型蒸馏至计算机视觉模型。

4.知识蒸馏的简单实现

以下是使用PyTorch实现知识蒸馏的代码示例,展示了如何将教师网络的知识蒸馏到学生网络中

代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# 定义简单的教师网络(Teacher Network)
class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64*12*12, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义较小的学生网络(Student Network)
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        self.fc1 = nn.Linear(32*12*12, 64)
        self.fc2 = nn.Linear(64, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 定义知识蒸馏的损失函数
def distillation_loss(student_outputs, teacher_outputs, labels, T, alpha):
    # 蒸馏损失(使用软标签)
    distill_loss = F.kl_div(F.log_softmax(student_outputs / T, dim=1),
                            F.softmax(teacher_outputs / T, dim=1),
                            reduction='batchmean') * T * T
    # 交叉熵损失(使用真实标签)
    ce_loss = F.cross_entropy(student_outputs, labels)
    
    # 总损失是蒸馏损失和交叉熵损失的加权和
    return alpha * ce_loss + (1 - alpha) * distill_loss

# 数据准备
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# 初始化教师和学生网络
teacher_model = TeacherNet()
student_model = StudentNet()

# 假设我们已经训练好教师网络,这里直接加载预训练权重
# teacher_model.load_state_dict(torch.load('teacher_model.pth'))

# 定义优化器
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# 训练过程
teacher_model.eval()  # 教师网络处于评估模式,不参与训练
num_epochs = 5
T = 5  # 温度
alpha = 0.7  # 权重系数

for epoch in range(num_epochs):
    student_model.train()
    
    for images, labels in train_loader:
        # 前向传播
        with torch.no_grad():  # 教师网络不更新梯度
            teacher_outputs = teacher_model(images)
        
        student_outputs = student_model(images)
        
        # 计算损失
        loss = distillation_loss(student_outputs, teacher_outputs, labels, T, alpha)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

代码解释:

  • 1.TeacherNetStudentNet 类:定义了教师网络和学生网络。教师网络较为复杂,而学生网络是较小的网络,用于模型压缩。
  • 2.distillation_loss 函数:计算总损失,其中包括蒸馏损失(基于教师网络的软标签)和交叉熵损失(基于真实标签)。蒸馏损失使用了温度参数 T 来控制软标签的平滑程度。
  • 3.teacher_model.eval():教师网络不参与训练,只是提供知识,因此在训练时设为评估模式。
  • 4.with torch.no_grad():确保在计算教师网络的输出时不计算梯度,减少内存开销和计算量。
  • 5.optimizer.step():更新学生网络的权重,通过反向传播让学生网络学习到教师网络的知识。

5.知识蒸馏的论文推荐

(1)Fitnets: Hints for thin deep nets(ICLR 2015)

论文链接:https://arxiv.org/pdf/1412.6550

主要内容:

  • 虽然深度倾向于提高网络性能,但它也使基于梯度的训练变得更加困难,因为更深的网络往往更非线性。最近提出的知识蒸馏方法旨在获得小型且快速执行的模型,并且已经表明学生网络可以模仿更大的教师网络或网络集成的软输出
  • 在本文中,我们扩展了这一思想,允许训练比教师更深更薄的学生,不仅使用输出,还使用教师学习到的中间表示作为提示,以改善训练过程和学生的最终表现。由于学生中间隐藏层通常会小于教师的中间隐藏层,因此引入额外的参数将学生隐藏层映射到教师隐藏层的预测。这样就可以培养出更深入的学生,他们可以更好地概括或跑得更快,这种权衡取决于所选学生的能力。
  • 例如,在CIFAR-10上,一个参数少了近10.4倍的深度学生网络比一个更大的、最先进的教师网络表现得更好。
    在这里插入图片描述

(2)Training data-efficient image transformers & distillation through attention(ICML 2021)

论文链接:https://proceedings.mlr.press/v139/touvron21a/touvron21a.pdf

主要内容:

  • 最近,纯粹基于注意力的神经网络被用于解决图像分类等图像理解任务。这些高性能视觉Transformer使用大型基础设施使用数亿张图像进行预训练,从而限制了它们的采用。
  • 在这项工作中,我们仅使用一台计算机在不到3天的时间内就在ImageNet上训练了具有竞争力的无卷积Transformer。我们的参考视觉Transformer(86M个参数)在没有外部数据的情况下在ImageNet上实现了83.1%(单次裁剪)的顶级精度。我们还介绍了一种针对Transformer的师生策略。
  • 它依赖于一个蒸馏token,确保学生通过注意力从教师那里学习,通常是一个教师卷积神经网络。学习到的Transformer与ImageNet上的最新技术相比具有竞争力(85.2%的top-1 acc),在转移到其他任务时也是如此。我们将分享我们的代码和模型。
    在这里插入图片描述

(3)Knowledge distillation via softmax regression representation learning(ICLR 2021)

论文地址:https://openreview.net/pdf?id=ZzwDy_wiWv
代码地址:https://github.com/jingyang2017/KD_SRRL

主要内容:

  • 本文通过知识蒸馏解决了模型压缩问题。我们提倡一种优化学生网络倒数第二层的输出特征的方法,因此与表示学习直接相关
  • 为此,我们首先提出了一种直接特征匹配方法,该方法只关注优化学生的倒数第二层。其次,更重要的是,由于特征匹配没有考虑到手头的分类问题,我们提出了第二种方法,将表示学习和分类解耦,并利用教师预训练的分类器来训练学生的倒数第二层特征。
  • 特别是,对于相同的输入图像,我们希望教师和学生的特征在通过教师的分类器时产生相同的输出,这是通过简单的L2损失来实现的。我们的方法实现起来非常简单,训练起来也很简单,并且在包括不同(a)网络架构、(b)师生能力、©数据集和(d)域在内的大量实验设置中,始终优于以前最先进的方法。
    在这里插入图片描述

(4)Decoupled Knowledge Distillation(CVPR 2022)

论文地址:https://arxiv.org/pdf/2203.08679
代码地址:https://github.com/megviiresearch/mdistiller

主要内容:

  • 现有的蒸馏方法主要是基于从中间层提取深层特征,而忽略了logit蒸馏的重要性。为了给logit蒸馏的研究提供一个新的视角,我们将经典的KD损失重新表述为两部分,即目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)。
  • 我们实证研究并证明了两部分的效果:TCKD转移了关于训练样本“难度”的知识,而NCKD是logit蒸馏有效的突出原因。更重要的是,我们揭示了经典KD损失是一个耦合公式,它**(I)抑制了NCKD的有效性,(2)限制了平衡这两个部分的灵活性**。
  • 为了解决这些问题,我们提出了解耦知识蒸馏(DKD),使TCKD和NCKD更有效和灵活地发挥其作用。与基于复杂特征的方法相比,我们的DKD在CIFAR-100、ImageNet和MSCOCO数据集上的图像分类和目标检测任务取得了相当甚至更好的结果,并且具有更好的训练效率。本文证明了logit精馏的巨大潜力,希望对今后的研究有所帮助。
    在这里插入图片描述

(5)Self-supervised models are good teaching assistants for vision transformers(ICML 2022)

论文地址:https://proceedings.mlr.press/v162/wu22c/wu22c.pdf
代码地址:https://github.com/GlassyWu/SSTA

主要内容:

  • 在过去的一年里,Transformer在计算机视觉任务上取得了显著的进步。与CNN相比,Transformer通常需要借助蒸馏才能在中小型数据集上获得可比的结果。同时,最近的研究发现,当Transformer分别以监督和自监督的方式进行训练时,捕获的模式在定性和定量上都有很大的不同
  • 这些发现促使我们在常用的监督式教师之外,引入自我监督式助教(self-supervised teaching assistant, SSTA)来改善Transformer的性能。具体来说,我们提出一种头级知识蒸馏方法,即选择被监督教师和自监督助教中最重要的头,让学生模仿这两个头的注意力分布,从而使学生关注教师和助教所认为的token之间的关系
  • 大量的实验验证了SSTA的有效性,并证明了所提出的SSTA对被监督教师是一种很好的补偿。与此同时,一些由监督教师、自我监督助教和学生进行的多角度分析实验(如预测、形状偏差、鲁棒性和下游任务可转移性)是归纳性的,可能会对未来的研究产生启发
    在这里插入图片描述

(6)Masked autoencoders enable efficient knowledge distillers(CVPR 2023)

论文地址:
https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10203918
代码地址:https://github.com/UCSC-VLAA/DMAE

主要内容:

  • 本文研究了从预训练模型中提取知识的潜力,特别是掩码自编码器。我们的方法很简单:除了优化掩模输入上的像素重建损失外,我们还最小化了教师模型和学生模型的中间特征映射之间的距离
  • 这种设计导致了一个计算效率很高的知识蒸馏框架,假设I)只使用了一小部分可见的补丁子集,并且2)(繁琐的)教师模型只需要部分执行,即通过前几层前向传播输入,以获得中间特征映射
  • 与直接提取微调模型相比,提取预训练模型大大提高了下游性能。例如,通过将MAE预训练的VIT-L中的知识提取到ViT-B中,我们的方法达到了84.0%的ImageNet top-1准确率,比直接提取微调后的VIT-L的基线高出1.2%。
  • 更有趣的是,我们的方法即使具有极高的掩蔽比,也可以从教师模型中稳健地提取知识:例如,在95%的掩蔽比下,在蒸馏过程中只有10个斑块可见,我们的ViT-B具有竞争力,达到了83.6%的top-1 ImageNet精度;令人惊讶的是,它仍然可以通过仅使用四个可见补丁(98%掩蔽率)进行积极训练来确保82.4%的top-1 ImageNet准确率。
    在这里插入图片描述

总结

知识蒸馏是一种有效的模型压缩技术,通过让学生网络学习教师网络的软标签和特征表示,能够在不显著降低性能的情况下大幅减小模型复杂度。它广泛应用于模型压缩、加速推理、迁移学习等任务。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值