这里是一种全新的模型部署方法SN-Net,利用现有的model family直接做少量epoch finetune就可以得到大量插值般存在的子网络,运行时任意切换网络结构满足不同resource constraint。

Paper: https://arxiv.org/abs/2302.06586

Code: https://github.com/ziplab/SN-Net

视频网站的视频播放会自动根据网络带宽调整画质,如网速好的时候到4K,网速差就720P甚至更低。那同一个神经网络能不能随时根据计算资源的变化调整推理速度?

从2012的AlexNet到2023年火出圈的ChatGPT, AI/ML这一社区在十年间少说已经训练了上百万个模型。截至这篇文章写作时,HuggingFace上可以直接下载的模型就有14万个,涵盖各个模态和任务。每个模型各司其职,用自己在训练中学到的知识去处理某一种场景,互不叨扰。

模型虽然越来越多,但是资源浪费也越来越严重。训练一个模型的成本很高,尤其是大模型训练,耗费数个节点和几天的算力才能得到一个好权重,但最后却受限于应用场景只能重新调整结构,然后再重新训练,如网络backbone设计中通常会有不同scale来满足不同的推理速度要求: ResNet-18/50/101,DeiT-Ti/S/B,Swin-Ti/S/B等等。

传统方法当然能加速模型推理,如pruning,distillation,quantization。但问题是这些方法一次大都只能针对一个模型,一个资源场景。我们也可以用NAS搜出来若干个子网络来满足不同推理速度需求,即使如此,NAS中训练一个Supernet的成本也是巨大的,典型的如OFA (https://arxiv.org/abs/1908.09791)和BigNAS (https://arxiv.org/abs/2003.11142),花费上千GPU hours才得到一个好网络,资源消耗巨大。

看着huggingface上这么大的model zoo,我们不禁想,整个社区花了大量时间,金钱和人力资源去训练网络,得到了这么多的 pretrained model,但是能不能有效利用起来? 况且这些模型已经训练好了,当需要他们的时候,能不能用少量计算资源就可以满足目标场景?

对这一问题的思考也是随着模型被工业界越推越大引出的。几年前一张1080就能跑完的实验,现在8张卡都很难train得动一个model,特别是Transformer出来之后。最新的ViT已经scale到22B,BAAI的 EVA (https://arxiv.org/abs/2211.07636)也把ViT扩展到了1B的参数级别。留给小组的空间越来越小,在资源有限(缺卡)的场景下,我们需要寻求新的突破方向。

Stitchable Neural Networks

Industry和Academia所关注的问题可以有些区别。既然大模型不是所有人能做得起的,那我们不如去利用好已有的pretrained model。现在我们有了一组训练好的model family,如DeiT-Tiny/Small/Base。不同模型有不同大小,推理速度,显存占用。那么能不能利用这些已有的weights和结构快速得到一批新网络来满足不同的资源场景?

我们在CVPR 2023最新的工作Stitchable Neural Network (SN-Net) 给出了一个非常具有潜力的方案。

SN-Net_插值

SN-Net的主要思想是:在一组已经训练好的model family中插入若干个stitching layer (即1x1 conv), 使得forward时activation可以在模型间的不同位置游走。当模型在不同位置缝合的时候,一个个新网络结构就出来了!!!

此时,我们把原先model family中的网络叫做anchors,缝合出来的新网络叫做stitches。单个SN-Net可以cover众多FLOPs-accuracy的trade-off,如在基于Swin的实验中,一个SN-Net的可以挑战timm中200个独立的模型,整个实验不过是50 epochs,八张V100上训练不到一天。

SN-Net_人工智能_02

下面会介绍详细的做法,以及我们当时方法设计时候的考虑。想直接看效果的朋友可以移步最后的结果展示。

1. 模型这么多,怎么去选择

这里主要考虑了几个地方:

  • 不同模型结构在网络中各层学习到的representation会有较大差别,缝合出来的网络不一定保证较好的performance。
  • 不同数据集学到的东西差别也很大,为了保证性能最好保持在相同pretrained的dataset下。
  • 不同网络的实现和训练方式有差别,工程上很难权衡超参和data augmentation的选择。而同一个结构通常在一个repo里,更容易实现。

因此,我们初步关注在相同dataset上训练好的model family上, 即结构相似,但是模型scale不一样,如DeiT-Ti/S/B。

不同family能不能缝合?也能,我们paper里有展示结果,但是工程上会比较麻烦,需要combine不同repo并且权衡超参。

2. 怎么去做缝合?

model stitching在原先工作中大都是以研究representation similarity的形式呈现的,如

  • Lenc, Karel, and Andrea Vedaldi. "Understanding image representations by measuring their equivariance and equivalence." CVPR 2015.
  • Kornblith, Simon, et al. "Similarity of neural network representations revisited." ICML, 2019.
  • Csiszárik, Adrián, et al. "Similarity and matching of neural network representations." NeurIPS 2021.

总结过去这些工作:同一个网络,用不同seed训练之后可以在某些位置缝合起来,此时性能不会掉的很离谱。后续的研究发现结构不一样的网络甚至也能缝合。

而stitching能够work在于,假设前一个网络出来的feature map属于activation 空间A,而另一个网络在此位置的输入feature map属于activation空间B,那么stitching layer做的事情就是把feature map从A空间映射到B空间,使得此时的feature map能模拟下一网络在这个位置的输入。

当网络是已经是pretrained,那么stitching这一过程完全可以formulate成一个求解least squares的问题。也就是说stitching layer这个weights的matrix是可以直接求出来的 (参考 Csiszárik, Adrián, et al (https://arxiv.org/abs/2110.14633) 这篇)。所以此时求解出来的matrix可以天然作为stitching layer的初始化。

3. 缝合方向的设定

SN-Net_插值_03

现在我们有一个大模型:性能好但是推理速度慢,还有一个小模型:性能差点但是推理速度快。我们怎么决定谁stitch到谁呢?我们主要考虑了两个方面:

  • 参考当前backbone设计的惯例,随着网络不断深入,channel dimension是在不断增大的。Fast-to-Slow这方向比较符合常见的网络设计。
  • 实验验证Fast-to-Slow得到的curve要比Slow-to-Fast要smooth一点,详见论文。

所以目前SN-Net在方向上是从小模型缝合到大模型。同时我们提出一个constraint: nearest stitching,限制stitching只在复杂度(FLOPs)相邻的两个anchor之间。如补充材料中的Figure 10所示,以DeiT-Ti/S/B为例,我们的方法目前限制在(a), (b)两个case。

SN-Net_Small_04

这个限制是因为我们发现anchor的gap比较大的时候,缝合出来的网络并不在一个optimal的区间。实验部分也证明直接stitch DeiT-Ti和DeiT-B效果不如中间加一个DeiT-S。

4. 怎么配置Stitching Layer

SN-Net_人工智能_05

我们以DeiT为例,在相同depth的缝合实验上采取了Paired Stitching这种策略。这种策略的启发来自于过去一些工作发现:相邻layer之间的representation是有较高的相似度的。所以我们选择在DeiT得相邻blocks中share同一个stitching layer,如滑窗一般进行stitching。

share的情况下,原先的初始化方法就是简单地对不同solution得到的matrix做一个average。选择share stitching layer还有其他好处,如减少过多stitching layer带来的参数量,同时扩大缝合出来的结构数量,即扩大stitching space。

另外一种情况是两个模型的depth不一样,小模型一般比较浅,block的数量要比大模型少。比如Swin-Ti的第三个stage只有6个block,而Swin-S在第三个stage有18个block。此时我们进行Unpaired Stitching,每个小模型的block都stitch到大模型的若干个block中。这样两个case就都解决了

5. SN-Net能缝出来多少网络?

这个由多种因素决定。

  • 看选择的model family,即anchors的depth。显然anchor越深,那么能stitch的位置就越多,新网络结构也会更多。
  • 相同depth下看stitching时sliding window的设置。
  • 不加nearest stitching的时候得到的网络更多 (DeiT上的实验是十倍的差距,71 vs. 731)。但是此时不optimal。后续潜力尚待挖掘。

对比NAS中 10^{20}10^{20} 级别的search space, SN-Net在基于同一组model family得到的网络数量是有限的。但有一点不得不提,纵使search space再大,真正需要的时候也只是用pareto frontier上的网络结构,而SN-Net缝合出来的网络几乎天然落在pareto frontier上,同时部署的时候完全可以直接查表,几乎没有什么search cost。

另外一点是,SN-Net的潜力在于整个pretrained model zoo。有多少model familiy,就有多少潜在的SN-Net变种。这是NAS的单一supernet所不能比拟的。这意味着我们可以轻易缝合已有的model family达到NAS耗费大量计算资源搜出来的网络性能,比如简单缝合两个LeViT (https://arxiv.org/abs/2104.01136)就可以用更低的FLOPs(977M vs. 1040M) 达到媲美 BigNASModel-XL (https://arxiv.org/abs/2003.11142)的性能(80.7% vs. 80.9%),如下图所示

SN-Net_插值_06

6. 简单的训练策略

训练SN-Net尤为简单。先提前把所有需要训练的stitches定义好,训练中每次iteration都随机sample出来一个stitch,后面和正常的训练一样进行loss回传,梯度下降。为了进一步提升stitches的性能,我们初步实验同时采用了RegNetY-160作为teacher model去做distillation。

SN-Net_github_07

结果展示

为了验证Joint Training和原有网络从头train的差距,我们选择了若干个和stitches相同的网络结构,然后在ImageNet上训满300 epochs。从下表可以看到,对比用了大量计算资源训练出来的网络,SN-Net利用已有的DeiT family只用50个epoch就可以得到比肩甚至更好的性能。同时整个网络只要118.4M的参数,而这71个stitches的总量如果单独训练需要2630M,耗费 71 × 300 epochs,和SN-Net比是22倍的差距。

SN-Net_插值_08

基于DeiT和Swin Transformer, 我们验证了缝合plain ViT和hierarchical ViT的可行性。性能曲线如在anchors中进行插值一般。 

SN-Net_插值_09

值得一提的是,图中不同点所表示的子网络,即stitch,是可以在运行时随时切换的。这意味着网络在runtime完全可以依靠查表进行瞬时推理速度调整。这个是诸多网络无法实现的,但颇具现实意义。比如现在很多手机都有省电模式,一旦进行power saving, 手机掉帧,系统运行速度变慢,而此时neural network也可以调整推理速度,做一个speed-accuracy的trade-off。

我们当然也尝试了stitch cnn,甚至不同的family,结果非常promising。

SN-Net_Small_10

更多实验内容和分析请移步我们的arxiv论文:Stitchable Neural Networks

SN-Net的可扩展空间

SN-Net生于large model zoo的时代。我们初版方法给出了一个最简单的baseline,相信未来有很大的扩展空间,比如

  1. 当前的训练策略比较简单,每次iteration sample出来一个stitch,但是当stitches特别多的时候,可能导致某些stitch训练的不够充分,除非增加训练时间。所以训练策略上可以继续改进。
  2. anchor的performance会比之前下降一些,虽然不大。直觉上,在joint training过程中,anchor为了保证众多stitches的性能在自身weights上做了一些trade-off。目前补充材料里发现finetune更多epoch可以把这部分损失补回来。
  3. 不用nearest stitching可以明显扩大space,但此时大部分网络不在pareto frontier上,未来可以结合训练策略进行改进,或者在其他地方发现advantage。
  4. 未来能否有个更好方法和统一的框架去缝合任意网络。到那时,整个model zoo就像积木一样,可操作空间更大,玩法更多,这一点NUS的Xingyi Yang (https://adamdad.github.io/)之前有尝试,参考Deep Model Reassembly (https://arxiv.org/abs/2210.17409).

更多探索就留给future work了。代码已经开源至https://github.com/ziplab/SN-Net,硬件要求十分友好,50个epoch (用8卡V100大约半天时间) 就可以复现结果。