【Knowledge distillation: A good teacher is patient and consistent】

在计算机视觉方面,实现最先进性能的大型模型与实际应用中简单的模型之间的差距越来越大。在本文中,将解决这个问题,并显著地弥补这2种模型之间的差距。

在实证研究中,作者的目标不是一定要提出一种新的方法,而是努力确定一种稳健和有效的配置方案,使最先进的大模型在实践中能够得到应用。本文证明了在正确使用的情况下,知识蒸馏可以在不影响大模型性能的情况下减小它们的规模。作者还发现有某些隐式的设计选择可能会极大地影响蒸馏的有效性。

作者的主要贡献是明确地识别了这些设计选择。作者通过一项全面的实证研究来支持本文的发现,在广泛的视觉数据集上展示了很不错的结果,特别是,为ImageNet获得了最先进的ResNet-50模型,达到了82.8%的Top-1精度。

一、简介.

大型视觉模型目前主导着计算机视觉的许多领域。最新的图像分类、目标检测或语义分割模型都将模型的大小推到现代硬件允许的极限。尽管它们的性能令人印象深刻,但由于计算成本高,这些模型很少在实践中使用。

相反,实践者通常使用更小的模型,如ResNet-50或MobileNet等,这些模型运行起来代价更低。根据Tensorflow Hub的5个BiT的下载次数,最小的ResNet-50的下载次数明显多于较大的模型。因此,许多最近在视觉方面的改进并没有转化为现实世界的应用程序。

为了解决这个问题,本文将专注于以下任务:给定一个特定的应用程序和一个在它上性能很好的大模型,目标是在不影响性能的情况下将模型压缩到一个更小、更高效的模型体系结构。针对这个任务有2种广泛使用的范例:模型剪枝和知识蒸馏。

模型剪枝通过剥离大模型的各个部分来减少大模型的大小。这个过程在实践中可能会有限制性:首先,它不允许更改模型族,比如从ResNet到MobileNet。其次,可能存在依赖于架构的挑战;例如,如果大模型使用GN,修剪通道可能导致需要动态地重新分配通道组。

相反,作者专注于没有这些缺点的知识蒸馏方法。知识蒸馏背后的理念是“提炼”一个教师模型,在本文例子中,一个庞大而繁琐的模型或模型集合,制成一个小而高效的学生模型。这是通过强迫学生模型的预测与教师模型的预测相匹配,从而自然地允许模型家族的变化作为压缩的一部分。

图2

密切遵循Hinton的原始蒸馏配置,发现如果操作正确,它惊人地有效;如图1所示作者将蒸馏解释为匹配教师和学生实现的函数的任务。通过这种解释发现对模型压缩的知识蒸馏的2个关键原则。

  • 首先,教师和学生模型应该处理完全相同的输入图像,或者更具体地说,相同的裁剪和数据增强;
  • 其次,希望函数在大量的支撑点上匹配,以便更好地推广。

使用Mixup的变体,可以在原始图像流形外生成支撑点。考虑到这一点,通过实验证明,一致的图像视图、合适的数据增强和非常长的训练计划是通过知识蒸馏使模型压缩在实践中工作良好的关键。

尽管发现明显很简单,但有很多种原因可能会阻止研究人员(和从业者)做出建议的设计选择。

  • 首先,很容易预先计算教师对离线图像的激活量,以节省计算量,特别是对于非常大的教师模型;
  • 其次,知识蒸馏也通常用于不同的上下文(除了模型压缩),其中作者推荐不同甚至相反的设计选择;
  • 最后,知识蒸馏需要比较多的Epoch来达到最佳性能,比通常用于监督训练的Epoch要多得多。更糟糕的是,在常规时间的训练中看起来不理想的选择往往是最好的,反之亦然。

在本文的实证研究中,主要集中于压缩大型BiT-ResNet-152x2,它在ImageNet-21k数据集上预训练,并对感兴趣的相关数据集进行微调。在不影响精度的情况下,将其蒸馏为标准的ResNet-50架构(用GN代替BN)。还在ImageNet数据集上取得了非常强的结果:总共有9600个蒸馏周期,在ImageNet上得到了新的ResNet-50SOTA结果,达到了惊人的82.8%。这比原始的ResNet-50模型高出4.4%,比文献中最好的ResNet-50模型高出2.2%。

最后,作者还证明了本文的蒸馏方案在同时压缩和更改模型时也可以工作,例如BiT-ResNet架构到MobileNet架构。

二、实验配置

2.1 Datasets, metrics and evaluation protocol
在5个流行的图像分类数据集上进行了实验:flowers102,pets,food101,sun397和ILSVRC-2012(“ImageNet”)。这些数据集跨越了不同的图像分类场景;特别是,它们的类的数量不同,从37到1000个类,训练图像的总数从1020到1281167个不等。

2.2 Teacher and student models
在本文中,选择使用来自BiT的预训练教师模型,该模型提供了大量在ILSVRC-2012和ImageNet-21k数据集上预训练的ResNet模型,具有最先进的准确性。BiT-ResNets与标准ResNets唯一显著的区别是使用了GN层和权重标准化。

特别地专注于BiT-M-R152x2架构:在ImageNet-21k上预训练的BiT-ResNet-152x2(152层,“x2”表示宽度倍数)。该模型在各种视觉基准上都显示出了优异的性能,而且它仍然可以使用它进行广泛的消融研究。尽管如此,它的部署成本还是很昂贵的(它需要比标准ResNet-50多10倍的计算量),因此该模型的有效压缩具有实际的重要性。对于学生模型的架构,使用了一个BiT-ResNet-50变体,为了简洁起见,它被称为ResNet-50。

2.3 Distillation loss
这里使用教师模型的和学生模型的之间的KL散度作为一个蒸馏损失来预测类概率向量。对于原始数据集的硬标签,不使用任何额外的损失:
在这里插入图片描述
C是类别。这里还引入了一个温度参数T,用于在损失计算之前调整预测的softmax-probability分布的熵:
在这里插入图片描述
2.4 Training setup
为了优化,使用带有默认参数的Adam优化器训练模型。还使用了不带有Warm up的余弦学习率机制。

作者同时还为所有的实验使用了解耦的权重衰减机制。为了稳定训练,在梯度的全局l2范数上以1.0的阈值进行梯度裁剪。最后,除在ImageNet上训练的模型使用batch size为4096进行训练外,对其他所有实验都使用batch size为512。

本文的方案的另一个重要组成部分是Mixup数据增强策略。特别在“函数匹配”策略中中引入了一个Mixup变量,其中使用从[0,1]均匀抽样的较强的Mixup系数,这可以看作是最初提出的β分布抽样的一个极端情况。

作者还使用了““inception-style”的裁剪,然后将图像的大小调整为固定的正方形大小。此外,为了能够广泛的分析在计算上的可行(训练了数十万个模型),除了ImageNet实验,使用标准输入224×224分辨率,其他数据集均使用相对较低的输入分辨率,并将输入图像的大小调整为128×128大小。

三、模型蒸馏

3.1 “consistent and patient teacher”假说
在本节中,对介绍中提出的假设进行实验验证,如图1所示,当作为函数匹配时,蒸馏效果最好,即当学生和教师模型输入图像是一致视图时,通过mixup合成“filled”,当学生模型接受长时间的训练时(即“教师”很有耐心)。

为了确保假说的稳健性,作者对4个中小型数据集进行了非常彻底的分析,即Flowers102,Pets,Food101,Sun397进行了训练。
在这里插入图片描述
为了消除任何混杂因素,作者对每个精馏设定使用学习速率{0.0003,0.001,0.003,0.01}与权重衰减{1× 1 0 − 5 10^{-5} 105,3× 1 0 − 5 10^{-5} 105,1× 1 0 − 4 10^{-4} 104,3× 1 0 − 4 10^{-4} 104,1× 1 0 − 3 10^{-3} 103}以及蒸馏温度{1,2,5,10}的所有组合。

3.1.1.Importance of “consistent” teaching
首先,证明了一致性标准,即学生和教师看到相同的视图,是执行蒸馏的唯一方法,它可以在所有数据集上一致地达到学生模型的最佳表现。在本研究中,定义了多个蒸馏配置,它们对应于图1中所示的所有4个选项的实例化:

1. Fixed teacher

作者探索了几个选项,其中教师模型的预测是恒定的,为一个给定的图像。

最简单(也是最差的)的方法是fix/rs,即学生和老师的图像大小都被调整到224x224pixel。

fix/cc遵循一种更常见的方法,即教师使用固定的central crop,而学生使用random crop。

fix/ic_ens是一种重数据增强方法,教师模型的预测是1024种inception crops的平均值,我们验证了以提高教师的表现。该学生模型使用random crop。

2. Independent noise

用2种方式实例化了这种常见的策略:

ind/rc分别为教师和学生计算2种独立的random crop;

ind/ic则使用heavy inception crop。

3. Consistent teaching

在这种方法中,只对图像进行随机裁剪一次,要么是mild random cropping(same/rc),要么是heavy inception crop(same/ic),并使用相同的crop向学生和教师模型提供输入。

4. Function matching

这种方法扩展了consistent teaching,通过mixup扩展图像的输入,并再次为学生和教师模型提供一致的输入。为了简洁起见,将这种方法称为“FunMatch”。

3.1.2 Importance of “patient” teaching
人们可以将蒸馏解释为监督学习的一种变体,其中标签是由一个强大的教师模型提供的。当教师模型的预测计算为单一图像视图时,这一点尤其正确。这种方法继承了标准监督学习的所有问题,例如,严重的数据增强可能会扭曲实际的图像标签,而轻微的增强可能又会导致过拟合。

然而,如果将蒸馏解释为函数匹配,并且最重要的是,确保为学生和老师模型提供一致的输入,情况就会发生变化。在这种情况下,可以进行比较强的图像增强:即使图像视图过于扭曲,仍然会在匹配该输入上的相关函数方面取得进展。因此,可以通过增强来增加机会,通过做比较强的图像增强来避免过拟合,如果正确,可以优化很长一段时间,直到学生模型的函数接近教师模型的函数。
在这里插入图片描述
在图4中证实了作者的假设,对于每个数据集,显示了在训练最佳函数匹配学生模型时不同数量的训练Epoch的测试精度的变化。教师模型为一条红线,经过比在标准监督训练中使用的更多的Epoch后,最终总是能够达到。至关重要的是,即使优化了一百万个Epoch,也没有过拟合的迹象。

作者还训练和调整了另外2个Baseline以供参考:使用数据集原始硬标签从零开始训练ResNet-50,以及传输在ImageNet-21k上预训练的ResNet-50。对于这2个Baseline,侧重于调整学习率和权重衰减。使用原始标签从零开始训练的模型大大优于学生模型。

值得注意的是,相对较短的100个Epoch的训练结果比迁移Baseline差得多。总的来说,ResNet-50的学生模型持续地匹配ResNet-152x2教师模型。

CIFAR-10 Example

以Cifar-10数据集为例,验证蒸馏得到的resnet-50模型的准确率

weights_cifar10 = get_weights('BiT-M-R50x1-CIFAR10')
model = ResNetV2(ResNetV2.BLOCK_UNITS['r50'], width_factor=1, head_size=10)  # NOTE: No new head.
model.load_from(weights_cifar10)
model.to(device);
def eval_cifar10(model, bs=100, progressbar=True):
  loader_test = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers=2)

  model.eval()

  if progressbar is True:
    progressbar = display(progress(0, len(loader_test)), display_id=True)

  preds = []
  with torch.no_grad():
    for i, (x, t) in enumerate(loader_test):
      x, t = x.to(device), t.numpy()
      logits = model(x)
      _, y = torch.max(logits.data, 1)
      preds.extend(y.cpu().numpy() == t)
      progressbar.update(progress(i+1, len(loader_test)))

  return np.mean(preds)
print("Expected: 97.61%")
print(f"Accuracy: {eval_cifar10(model):.2%}")

评估预训练模型,输出如下:

Expected: 97.61%

Accuracy: 97.62%

找到索引以创建5个镜头的CIFAR10变体

preprocess_tiny = tv.transforms.Compose([tv.transforms.CenterCrop((2, 2)), tv.transforms.ToTensor()])
trainset_tiny = tv.datasets.CIFAR10(root='./data', train=True, download=True, transform=preprocess_tiny)
loader = torch.utils.data.DataLoader(trainset_tiny, batch_size=50000, shuffle=False, num_workers=2)
images, labels = iter(loader).next()
indices = {cls: np.random.choice(np.where(labels.numpy() == cls)[0], 5, replace=False) for cls in range(10)}
print(indices)
fig = plt.figure(figsize=(10, 4))
ig = ImageGrid(fig, 111, (5, 10))
for c, cls in enumerate(indices):
  for r, i in enumerate(indices[cls]):
    img, _ = trainset[i]
    ax = ig.axes_column[c][r]
    ax.imshow((img.numpy().transpose([1, 2, 0]) * 127.5 + 127.5).astype(np.uint8))
    ax.set_axis_off()
fig.suptitle('The whole 5-shot CIFAR10 dataset');
train_5shot = torch.utils.data.Subset(trainset, indices=[i for v in indices.values() for i in v])
len(train_5shot)

输出如下

50

微调BiT-M(resnet-50)在这个5-shot CIFAR10变体上

model = ResNetV2(ResNetV2.BLOCK_UNITS['r50'], width_factor=1, head_size=10, zero_head=True)
model.load_from(weights)
model.to(device);
sampler = torch.utils.data.RandomSampler(train_5shot, replacement=True, num_samples=256)
loader_train = torch.utils.data.DataLoader(train_5shot, batch_size=256, num_workers=2, sampler=sampler)
crit = nn.CrossEntropyLoss()
opti = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
model.train();
S = 500
def schedule(s):
  step_lr = stairs(s, 3e-3, 200, 3e-4, 300, 3e-5, 400, 3e-6, S, None)
  return rampup(s, 100, step_lr)

pb_train = display(progress(0, S), display_id=True)
pb_test = display(progress(0, 100), display_id=True)
losses = [[]]
accus_train = [[]]
accus_test = []

steps_per_iter = 512 // loader_train.batch_size

while len(losses) < S:
  for x, t in loader_train:
    x, t = x.to(device), t.to(device)

    logits = model(x)
    loss = crit(logits, t) / steps_per_iter
    loss.backward()
    losses[-1].append(loss.item())

    with torch.no_grad():
      accus_train[-1].extend(torch.max(logits, dim=1)[1].cpu().numpy() == t.cpu().numpy())

    if len(losses[-1]) == steps_per_iter:
      losses[-1] = sum(losses[-1])
      losses.append([])
      accus_train[-1] = np.mean(accus_train[-1])
      accus_train.append([])

      # Update learning-rate according to schedule, and stop if necessary
      lr = schedule(len(losses) - 1)
      for param_group in opti.param_groups:
        param_group['lr'] = lr

      opti.step()
      opti.zero_grad()

      pb_train.update(progress(len(losses) - 1, S))
      print(f'\r[Step {len(losses) - 1}] loss={losses[-2]:.2e} '
            f'train accu={accus_train[-2]:.2%} '
            f'test accu={accus_test[-1] if accus_test else 0:.2%} '
            f'(lr={lr:g})', end='', flush=True)

      if len(losses) % 25 == 0:
        accus_test.append(eval_cifar10(model, progressbar=pb_test))
        model.train()

得到的损失函数、训练准确率和测试准确率输出如下

[Step 499] loss=2.23e-05 train accu=100.00% test accu=86.41% (lr=3e-06)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))
ax1.plot(losses[:-1])
ax1.set_yscale('log')
ax1.set_title('loss')
ax2.plot(accus_train[:-1])
ax2.set_title('training accuracy')
ax3.plot(np.arange(25, 501, 25), accus_test)
ax3.set_title('test accuracy');

得到的损失函数、训练准确率和测试准确率图像输出如下
在这里插入图片描述
参考文章:
让ResNet-50精度高达82.8%!ViT原作者的知识蒸馏新作 | CVPR 2022 Oral

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值