这篇论文提出了Segment anything的一个轻量化小模型,使SAM能够更高效地部署在移动端上,23年12月挂到了arxiv上,具体方法是知识蒸馏。
这篇论文使用方法分为三部分,第一部分是只对图像编码器的知识蒸馏,第二部分是循环prompt蒸馏,第三部分是可选的粒度先验模块。前两部分就像下图中说明的一样,每个部分对应一个蒸馏loss。
对于第一部分,只对编码器的蒸馏,来自于MobileSAM这篇论文,蒸馏损失就是教师SAM图像编码器和学生SAM编码器之间的均方误差。作者一开始只使用这部分损失,但是模型效果并不好,于是又考虑加入下面的prompt循环蒸馏损失。
其实prompt循环蒸馏也比较简单,作者在发现单纯对图像编码器蒸馏之后效果不好就打算把整个网络一起蒸馏了,包括了prompt编码器和掩码解码器。此时就需要考虑prompt要怎么输出了,这里作者对于初始prompt选择的是SA-1B数据集中随机采样得到的点或框prompt。把prompt embedding和图像embedding输入解码器就可以拿到掩模,这时候取两个模型输出的掩模计算loss。
要注意第二部分是循环prompt蒸馏,这里面是有一个循环在的,也就是根据初始prompt要去循环构建一串prompt。构造原则可以看左图,这里图上显示的是比较清楚的。在已有prompt下教师模型和学生模型都会输出一个掩模,将教师掩模视为ground truth,那么学生掩模没覆盖到的教师掩模部分就是False Negative,学生掩模多覆盖的部分就是False Positive。下一个Positive的点prompt就可以在FN部分任取,下一个Negative的点prompt可以在FP部分任取。按照生成方法可以根据初始的一个prompt生成一串prompt,这也就是循环prompt蒸馏的过程了。
右图是作者对循环prompt蒸馏做的消融实验,从这也能看出来其实第一个prompt可能是框或者点,但之后的prompt都是点了。
接下来就是最后一个部分粒度先验模块,这个模块可加可不加,因为加上以后会带来更大的计算量,不过也可以相应地提升模型分割准确率。
作者设计这个模块是考虑到下面左图这种情况,由于SA-1B数据集是多粒度分割的,所以对于一个单独的点prompt很难确定它的输出粒度,但是通过框prompt就很容易确定分割粒度了。所以这里作者提出了一个轻量的区域候选网络RPN,由特征金字塔FPN和检测头组成,训练时在COCO数据集上训练以提前捕获物体粒度。在推理的时候根据已有的点prompt生成一个额外补充框prompt。具体方法就是找到已有点的k近邻的候选框,而且这个距离还要和框的置信度加权,然后将这k个候选框合并成一个框prompt。
右边是作者做的消融实现,可以看到加上RPN后mIoU提升了,但是帧率FPS(每秒处理的图像帧数)下降了。
接下来是模型的训练过程,这里训练分为三阶段,第一阶段就是用纯图像编码器的蒸馏损失来训练,训练数据集来自SA-1B的1%。第二阶段在相同的数据集上,使用循环prompt蒸馏损失来训练整个模型。第三阶段可有可无,是专门用来训练RPN模块的,这里是冻结除了RPN之外的所有模块,然后在某些特定数据集上训练(例如COCO),目的是捕捉这些数据集上的粒度先验知识。
之后就是模型的性能展示,下面两张图展示的信息就是EdgeSAM在分割准确率上仅次于原始的教师模型SAM,同时参数量、帧率优于所有模型。