深度学习笔记

目录

摘要

Abstract

SAM模型

整体设计

image encoder

prompt encoder

mask decoder

总结


摘要

Segement Anything Model一经发布就获得了无数关注,当然更多人关注的是他庞大的数据集以及生成这个数据集所用的Data Engine,构建数据集的贡献诚然巨大,但是SAM本身结构的设计也是有相当的想法。尽管模型本身说不上十分复杂,但是也值得认真学习。

Abstract

The Segment Anything Model has garnered countless attention since its release, with more focus on its massive dataset and the Data Engine used to generate it. While the contribution of building the dataset is significant, the design of SAM's own structure also has considerable ideas. Although the model itself cannot be said to be very complex, it is still worth studying seriously.

SAM模型

整体设计

 

模型整体上包含三个大模块,image encoder,prompt encoder和mask decoder。

  • image encoder旨在映射待分割的图像到图像特征空间。
  • prompt encoder则是负责映射输入的prompt到prompt的特征空间,这里有一点要提就是作者定义了sparse和dense两种prompt,其中sparse prompt比较好理解,就是指demo中我们可以输入的点,目标框或者是描述目标的text,而dense prompt在目前的线上demo中体验不到,paper中也只说它对应的是mask类型的prompt,从代码里看应该是训练时候用的比较多,一般是上一次迭代预测出的一个粗分割的mask,粗略指出待分割的目标区域。
  • mask decoder的意义从功能上说有两个,一是整合image encoder和prompt encoder分别输出的两个embedding,然后从这个embedding的feature map解码出最终的分割mask。

SAM的一个我个人认为比较新颖的点子是它从interactive segmentation引申出了一个新的任务类型,叫做promptable segmentation。从他的模型中也能看出,输入的prompt是模型在输出最终mask的关键指导信息,这也是为什么我发现目前的SAM模型在处理一些专业领域图像(比如我自己从事的医学图像分割)时,直接使用他的segment everything功能,也就是无prompt进行分割时效果不好的原因。

另一个要搞清楚的问题是在进行有prompt的分割时,实际上实现的是一个二分类的分割任务,模型要解决的问题是根据我们选择的点的特征,从图像(背景)中分割出这个点所在的目标物体(前景),它本质上并不关心这个目标物体是个什么东西。滑稽一点来说,整个过程实际上有点类似photoshop里魔棒的功能,adobe倒是可以考虑把这个模型整合进ps里提升一些性能。

下面的部分我们分别分析一下每个模块的结构和意义。

image encoder

正如上面所说,image encoder的作用是把图像映射到特征空间,那么,如paper里所说,本质上这个encoder可以是任何网络结构,作者这里使用的是微调的detectron的ViT,当然它也可以被改成传统的卷积结构,非常合理。

这里的ViT结构也并不是十分复杂,这里简单列出输入图像经过ViT的流程,其实整体只有4个步骤:

  • 输入图像进入网络,先经过一个卷积base的patch_embedding:取16*16为一个patch,步长也是16,这样feature map的尺寸就缩小了16倍,同时channel从3映射到768。
  • patch_embed过后加positional_embedding:positional_embedding是个可学习的参数矩阵,初始化是0。
  • 加了positional_embedding后的feature map过16个transformer block,其中12个transformer是基于window partition(就是把特征图分成14*14的windows做局部的attention)的attn模块,和4个全局attn,这4个全局attn是穿插在windowed attention中的。
  • 最后过两层卷积(neck)把channel数降到256,这就是最终的image embedding的结果。

整体来看,这个部分的计算量是相对来说比较大的,demo体验过程中,只有这个过程的计算是在fb的服务器上做的,prompt encoder和mask decoder体积比较小,都是在浏览器内部或者说用本地的内存跑的,整体速度还比较快。

prompt encoder

prompt encoder负责prompt到prompt的特征空间,映射出的特征的channel和image embedding的channel一致,因为这两个后边要用attention进行融合。这里先不讨论text prompt的情况,nlp的映射处理我个人不太了解。

如果prompt是point,那么它的映射由两个部分相加组成,一个是位置编码,这里的位置编码使用的是Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains的编码方式,这里我没有仔细研读原论文,不是十分确定,大概的意思是,用空间坐标乘以高斯分布的向量来描述位置比直接的线性向量描述效果更好,另一个部分是一个描述当前点是前景还是背景(因为demo里可以选择pos点也可以选择neg点)特征的可学习的一维向量。换句话说,如果当前选择的点是positive,那么就在位置编码的2维向量上加一个表示postitive的一维向量,如果是neg,就加一个表示neg的一维向量,对于所有的positive的点,加上去的pos向量都是一样的。

如果prompt是box,那box的映射也是由两个部分相加组成,第一部分是左上和右下两个点的位置编码,第二部分是一组一维向量用来描述这个点是“左上”还是“右下”。也就是说,对于左上的点,他的映射就是位置编码+“左上”这个特征的描述向量,右下的点,就是位置编码+“右下”这个特征的描述向量。

上面的两个都属于sparse prompt,那么对于mask这类的dense prompt,他的映射就比较简单粗暴。在输入prompt encoder之前,先要把mask降采样到4x,再过两个2x2,stride=2的卷积,这样尺寸又降了4x,就和降了16x的图像特征图尺寸一致了,再过一个1*1的卷积,把channel也升到256。如果没有提供mask,也就是我们实际inference时候的场景,这个结构会直接返回一个描述“没有mask”特征的特征图。

mask decoder

mask decoder是整个模型中结构相对来说唯一比较复杂的部分,这里放一张paper里的结构图:

mask decoder结构,略显盘根错节

decoder的结构之所以看起来复杂,主要原因是prompt embedding和image embedding在这个结构中反复融合并且反复更新,从这里同样可以看出prompt在这个任务中的重要地位。

我们从左至右逐步分析decoder的流程,

  • 在prompt embedding进入decoder之前,先在它上面concat了一组可学习的output tokens,output tokens由两个部分构成:
    • 一个是iou token,它会在后面被分离出来用于预测iou的可靠性(对应结构图右侧的IoU output token),它受到模型计算出的iou与模型计算出的mask与GT实际的iou之间的MSE loss监督;
    • 另一个是mask token,它也会在后面被分离出来参与预测最终的mask(对应结构图右侧的output token per mask),mask受到focal lossdice loss 20:1的加权组合监督。
    • 这两个token的意义我感觉比较抽象,因为理论来说进入decoder的变量应该是由模型的输入,也就是prompt和image的映射构成,但这两个token的定义与prompt和image完全没有关系,而是凭空出现的。从结果反推原因,只能把它们理解成对模型的额外约束,因为它们两个参与构成了模型的两个输出并且有loss对他们进行监督。
    • 最终prompt embedding(这一步改名叫prompt token)和刚才提到这两个token concat到一起统称为tokens进入decoder。
  • image embedding在进入decoder之前也要进行一步操作:dense prompt由于包含密集的空间信息,与image embedding所在的特征空间一致性更高,所以直接与image embedding相加融合。因为后面要与prompt做cross attention融合,这里还要先算一下image embedding的位置编码。
  • 接下来{image embedding,image embedding的位置编码,tokens}进入一个两层transformer结构的decoder做融合。值得注意的是,在transformer结构中,为了保持位置信息始终不丢失,每做一次attention运算,不管是self-attention还是cross-attention,tokens都叠加一次初始的tokens,image embedding都叠加一次它自己的位置编码,并且每个attention后边都接一个layer_norm。
    • tokens先过一个self-attention。
    • tokens作为q,对image embedding做cross attention,更新tokens。
    • tokens再过两层的mlp做特征变换。
    • image embedding作为q,对tokens做cross attention,更新image embedding。
  • 更新后的tokens作为q,再对更新后的image embedding做cross attention,产生最终的tokens。
  • 更新后的image embedding过两层kernel_size=2, stride=2的转置卷积,升采样到4x大小(依然是4x降采样原图的大小),产生最终的image embedding。
  • 接下来兵分两路:
    • mask token被从tokens中分离出来(因为他一开始就是concat上去的,可以直接按维度摘出来),过一个三层的mlp调整channel数与最终的image embedding一致,并且他们两个做矩阵乘法生成mask的预测。
    • iou token被从tokens中分离出来,也过一个三层的mlp生成最终的iou预测。
  • 最后,如前文所述,分别对mask的预测和iou预测进行监督,反向传播,更新参数。

总结

还是那句话,整体上来说,SAM的结构说不上十分复杂,除了mask decoder有一点绕之外,整体的结构都是直上直下。我个人觉得在学习SAM的过程中一定要保持的一个认知是,它是一个新的promptable segmentation任务,它对prompt所指定的局部信息的关注度更高,而不像传统的语义分割通常要在全局搜索目标,所以要跳出语义分割中对于邻域,对于多尺度的执念,也许能更容易地理解SAM的思路。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值