CVPR 2023 | 大模型流行之下,SN-Net给出一份独特的答卷

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

点击进入—>【计算机视觉】微信技术交流群

作者:盘子正 |(已授权转载)编辑:CVer

https://zhuanlan.zhihu.com/p/611257510

写在前面:本文介绍我们组在CVPR 2023的工作:Stitchable Neural Networks,下文简称SN-Net。一种全新的模型部署方法,利用现有的model family直接做少量epoch finetune就可以得到大量插值般存在的子网络,运行时任意切换网络结构满足不同resource constraint。

e17395a88eaa4cb17a688742665ca974.png

Paper: https://arxiv.org/abs/2302.06586

Code: https://github.com/ziplab/SN-Net

背景

去年一次组会上,在和导师们讨论未来的research方向的时候,偶然聊到一个问题:

视频网站的视频播放会自动根据网络带宽调整画质,如网速好的时候到4K,网速差就720P甚至更低。那同一个神经网络能不能随时根据计算资源的变化调整推理速度?

从2012的AlexNet到2023年火出圈的ChatGPT, AI/ML这一社区在十年间少说已经训练了上百万个模型。截至这篇文章写作时,HuggingFace上可以直接下载的模型就有14万个,涵盖各个模态和任务。每个模型各司其职,用自己在训练中学到的知识去处理某一种场景,互不叨扰。

a89b2f39990eff4136e14e7a1adcd8a0.jpeg

模型虽然越来越多,但是资源浪费也越来越严重。训练一个模型的成本很高,尤其是大模型训练,耗费数个节点和几天的算力才能得到一个好权重,但最后却受限于应用场景只能重新调整结构,然后再重新训练,如网络backbone设计中通常会有不同scale来满足不同的推理速度要求: ResNet-18/50/101,DeiT-Ti/S/B,Swin-Ti/S/B等等。

传统方法当然能加速模型推理,如pruning,distillation,quantization。但问题是这些方法一次大都只能针对一个模型,一个资源场景。我们也可以用NAS搜出来若干个子网络来满足不同推理速度需求,即使如此,NAS中训练一个Supernet的成本也是巨大的,典型的如OFA和BigNAS,花费上千GPU hours才得到一个好网络,资源消耗巨大。

看着huggingface上这么大的model zoo,我们不禁想,整个社区花了大量时间,金钱和人力资源去训练网络,得到了这么多的 pretrained model,但是能不能有效利用起来?况且这些模型已经训练好了,当需要他们的时候,能不能用少量计算资源就可以满足目标场景?

对这一问题的思考也是随着模型被工业界越推越大引出的。几年前一张1080就能跑完的实验,现在8张卡都很难train得动一个model,特别是Transformer出来之后。最新的ViT已经scale到22B,BAAI的EVA也把ViT扩展到了1B的参数级别。留给小组的空间越来越小,在资源有限(缺卡)的场景下,我们需要寻求新的突破方向。

Stitchable Neural Networks

Industry和Academia所关注的问题可以有些区别。既然大模型不是所有人能做得起的,那我们不如去利用好已有的pretrained model。现在我们有了一组训练好的model family,如DeiT-Tiny/Small/Base。不同模型有不同大小,推理速度,显存占用。那么能不能利用这些已有的weights和结构快速得到一批新网络来满足不同的资源场景?

我们在CVPR 2023最新的工作Stitchable Neural Network (SN-Net) 给出了一个非常具有潜力的方案。

2a3e22f989040c2c426d593ef50cebd1.jpeg

SN-Net的主要思想是:在一组已经训练好的model family中插入若干个stitching layer (即1x1 conv), 使得forward时activation可以在模型间的不同位置游走。当模型在不同位置缝合的时候,一个个新网络结构就出来了!!!

此时,我们把原先model family中的网络叫做anchors,缝合出来的新网络叫做stitches。单个SN-Net可以cover众多FLOPs-accuracy的trade-off,如在基于Swin的实验中,一个SN-Net的可以挑战timm中200个独立的模型,整个实验不过是50 epochs,八张V100上训练不到一天。

706e63d50d80abf7c49ebf46c215227a.jpeg

下面会介绍详细的做法,以及我们当时方法设计时候的考虑。想直接看效果的朋友可以移步最后的结果展示。

1. 模型这么多,怎么去选择

这里主要考虑了几个地方:

  • 不同模型结构在网络中各层学习到的representation会有较大差别,缝合出来的网络不一定保证较好的performance。

  • 不同数据集学到的东西差别也很大,为了保证性能最好保持在相同pretrained的dataset下。

  • 不同网络的实现和训练方式有差别,工程上很难权衡超参和data augmentation的选择。而同一个结构通常在一个repo里,更容易实现。

因此,我们初步关注在相同dataset上训练好的model family上, 即结构相似,但是模型scale不一样,如DeiT-Ti/S/B。

不同family能不能缝合?也能,我们paper里有展示结果,但是工程上会比较麻烦,需要combine不同repo并且权衡超参。

2. 怎么去做缝合?

model stitching在原先工作中大都是以研究representation similarity的形式呈现的,如

  • Lenc, Karel, and Andrea Vedaldi. "Understanding image representations by measuring their equivariance and equivalence." CVPR 2015.

  • Kornblith, Simon, et al. "Similarity of neural network representations revisited." ICML, 2019.

  • Csiszárik, Adrián, et al. "Similarity and matching of neural network representations." NeurIPS 2021.

总结过去这些工作:同一个网络,用不同seed训练之后可以在某些位置缝合起来,此时性能不会掉的很离谱。后续的研究发现结构不一样的网络甚至也能缝合。

而stitching能够work在于,假设前一个网络出来的feature map属于activation 空间A,而另一个网络在此位置的输入feature map属于activation空间B,那么stitching layer做的事情就是把feature map从A空间映射到B空间,使得此时的feature map能模拟下一网络在这个位置的输入。

当网络是已经是pretrained,那么stitching这一过程完全可以formulate成一个求解least squares的问题。也就是说stitching layer这个weights的matrix是可以直接求出来的 (参考Csiszárik, Adrián, et al 这篇)。所以此时求解出来的matrix可以天然作为stitching layer的初始化。

3. 缝合方向的设定

7c9e620149c7f7fb11253fa949b3ea97.jpeg

现在我们有一个大模型:性能好但是推理速度慢,还有一个小模型:性能差点但是推理速度快。我们怎么决定谁stitch到谁呢?我们主要考虑了两个方面:

  • 参考当前backbone设计的惯例,随着网络不断深入,channel dimension是在不断增大的。Fast-to-Slow这方向比较符合常见的网络设计。

  • 实验验证Fast-to-Slow得到的curve要比Slow-to-Fast要smooth一点,详见论文。

所以目前SN-Net在方向上是从小模型缝合到大模型。同时我们提出一个constraint: nearest stitching,限制stitching只在复杂度(FLOPs)相邻的两个anchor之间。如补充材料中的Figure 10所示,以DeiT-Ti/S/B为例,我们的方法目前限制在(a), (b)两个case。

4e90256fd1eeb4afddaba60fb2bf756e.jpeg

这个限制是因为我们发现anchor的gap比较大的时候,缝合出来的网络并不在一个optimal的区间。实验部分也证明直接stitch DeiT-Ti和DeiT-B效果不如中间加一个DeiT-S。

4. 怎么配置Stitching Layer

7e256842032cd353df1b62a5203d5605.jpeg

网络设计地千奇百怪,怎么去缝合是个问题。

我们以DeiT为例,在相同depth的缝合实验上采取了Paired Stitching这种策略。这种策略的启发来自于过去一些工作发现:相邻layer之间的representation是有较高的相似度的。所以我们选择在DeiT得相邻blocks中share同一个stitching layer,如滑窗一般进行stitching。

share的情况下,原先的初始化方法就是简单地对不同solution得到的matrix做一个average。选择share stitching layer还有其他好处,如减少过多stitching layer带来的参数量,同时扩大缝合出来的结构数量,即扩大stitching space。

另外一种情况是两个模型的depth不一样,小模型一般比较浅,block的数量要比大模型少。比如Swin-Ti的第三个stage只有6个block,而Swin-S在第三个stage有18个block。此时我们进行Unpaired Stitching,每个小模型的block都stitch到大模型的若干个block中。这样两个case就都解决了

5. SN-Net能缝出来多少网络?

这个由多种因素决定。

  • 看选择的model family,即anchors的depth。显然anchor越深,那么能stitch的位置就越多,新网络结构也会更多。

  • 相同depth下看stitching时sliding window的设置。

  • 不加nearest stitching的时候得到的网络更多 (DeiT上的实验是十倍的差距,71 vs. 731)。但是此时不optimal。后续潜力尚待挖掘。

对比NAS中 10^20 级别的search space, SN-Net在基于同一组model family得到的网络数量是有限的。但有一点不得不提,纵使search space再大,真正需要的时候也只是用pareto frontier上的网络结构,而SN-Net缝合出来的网络几乎天然落在pareto frontier上,同时部署的时候完全可以直接查表,几乎没有什么search cost。

另外一点是,SN-Net的潜力在于整个pretrained model zoo。有多少model familiy,就有多少潜在的SN-Net变种。这是NAS的单一supernet所不能比拟的。这意味着我们可以轻易缝合已有的model family达到NAS耗费大量计算资源搜出来的网络性能,比如简单缝合两个LeViT就可以用更低的FLOPs(977M vs. 1040M) 达到媲美BigNASModel-XL的性能(80.7% vs. 80.9%),如下图所示

62ee2dad8ee1e5e05f4e9019207ca9b9.jpeg

6. 简单的训练策略

训练SN-Net尤为简单。先提前把所有需要训练的stitches定义好,训练中每次iteration都随机sample出来一个stitch,后面和正常的训练一样进行loss回传,梯度下降。为了进一步提升stitches的性能,我们初步实验同时采用了RegNetY-160作为teacher model去做distillation。

0691864b8cb178d8102ab38146e07dea.jpeg

结果展示

为了验证Joint Training和原有网络从头train的差距,我们选择了若干个和stitches相同的网络结构,然后在ImageNet上训满300 epochs。从下表可以看到,对比用了大量计算资源训练出来的网络,SN-Net利用已有的DeiT family只用50个epoch就可以得到比肩甚至更好的性能。同时整个网络只要118.4M的参数,而这71个stitches的总量如果单独训练需要2630M,耗费 71 × 300 epochs,和SN-Net比是22倍的差距。

3e97a5f9a9d08d6320d8a82ab5947fdf.jpeg

基于DeiT和Swin Transformer, 我们验证了缝合plain ViT和hierarchical ViT的可行性。性能曲线如在anchors中进行插值一般。

93c9dfe04f0cddfc30611d098275c16d.jpeg

值得一提的是,图中不同点所表示的子网络,即stitch,是可以在运行时随时切换的。这意味着网络在runtime完全可以依靠查表进行瞬时推理速度调整。这个是诸多网络无法实现的,但颇具现实意义。比如现在很多手机都有省电模式,一旦进行power saving, 手机掉帧,系统运行速度变慢,而此时neural network也可以调整推理速度,做一个speed-accuracy的trade-off。

我们当然也尝试了stitch cnn,甚至不同的family,结果非常promising。

52c6d08ce371d9cedea19ecf4b24c5d8.jpeg

更多实验内容和分析请移步我们的arxiv论文:Stitchable Neural Networks

SN-Net的可扩展空间

SN-Net生于large model zoo的时代。我们初版方法给出了一个最简单的baseline,相信未来有很大的扩展空间,比如

  1. 当前的训练策略比较简单,每次iteration sample出来一个stitch,但是当stitches特别多的时候,可能导致某些stitch训练的不够充分,除非增加训练时间。所以训练策略上可以继续改进。

  2. anchor的performance会比之前下降一些,虽然不大。直觉上,在joint training过程中,anchor为了保证众多stitches的性能在自身weights上做了一些trade-off。目前补充材料里发现finetune更多epoch可以把这部分损失补回来。

  3. 不用nearest stitching可以明显扩大space,但此时大部分网络不在pareto frontier上,未来可以结合训练策略进行改进,或者在其他地方发现advantage。

  4. 未来能否有个更好方法和统一的框架去缝合任意网络。到那时,整个model zoo就像积木一样,可操作空间更大,玩法更多,这一点NUS的Xingyi Yang之前有尝试,参考Deep Model Reassembly.

更多探索就留给future work了。代码已经开源至https://github.com/ziplab/SN-Net,硬件要求十分友好,50个epoch (用8卡V100大约半天时间) 就可以复现结果。欢迎有兴趣的同学进行尝试!

个人主页:https://zizhengpan.github.io/

实验室主页:https://ziplab.github.io/

文中如有错误,欢迎指出,同时欢迎各位进行学术交流~

点击进入—>【计算机视觉】微信技术交流群

最新CVPP 2023论文和代码下载

 
 

后台回复:CVPR2023,即可下载CVPR 2023论文和代码开源的论文合集

后台回复:Transformer综述,即可下载最新的3篇Transformer综述PDF

目标检测和Transformer交流群成立
扫描下方二维码,或者添加微信:CVer333,即可添加CVer小助手微信,便可申请加入CVer-目标检测或者Transformer 微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer等。
一定要备注:研究方向+地点+学校/公司+昵称(如目标检测或者Transformer+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

▲扫码或加微信号: CVer333,进交流群
CVer学术交流群(知识星球)来了!想要了解最新最快最好的CV/DL/ML论文速递、优质开源项目、学习教程和实战训练等资料,欢迎扫描下方二维码,加入CVer学术交流群,已汇集数千人!

▲扫码进群
▲点击上方卡片,关注CVer公众号
整理不易,请点赞和在看
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值