SAM,分割一切

1.背景

在网络数据集上预训练的大语言模型具有强大的zero-shot(零样本)和few-shot(少样本)的泛化能力,这些"基础模型"可以推广到超出训练过程中的任务和数据分布,这种能力通过“prompt engineering”实现,具体就是输入提示语得到有效的文本输出,使用网络上的大量文本资料库进行缩放和训练后,发现这种零样本和少样本的训练的模型比微调模型效果还要好,数据集越大,效果越明显。
视觉任务上也对这种基础模型进行了探索,比如CLIPALIGN利用对比学习,将文本和图像编码进行了对齐,通过提示语生成image encoder,就可以扩展到下游任务,比如生成图像
论文的目的是建立一个图像分割的基础模型,开发一个具有提示能力的模型。

要解决的3个问题:

(1)什么任务可以实现零样本

通过提示输入,生成有效的mask,当提示是不确定的,能生成多个objects(比如衣服上的一个点,既可以表示衣服,也表示穿衣服的人),如下图所示:提示可以是点,矩形框,文字,mask,或者是图像。

prompt提示

(2)模型结构应该是什么样?

模型要支持灵活的提示,且要实时生成mask,对输出也是模糊的(比如表示衣服还是穿衣服的人),设计结构如下:一个prompt encoder,对提示进行编码,image encoder对图像编码,生成embedding, 最后融合2个encoder,再接一个轻量的mask decoder,输出最后的mask。

prompt encoder

(3)数据怎么支持这些任务?

需要一个大量且多样化的mask数据。自然语言数据是通过在线获取,但是mask数据是不足的,需要一个替代策略。

方案就是建立一个“数据引擎”,分成3步:人工辅助(帮助标注,类似交互式分割),半自动(通过提供提示,自动生成对象mask),全自动(通过规则格网作为提示,进行自动生成)。

如下图所示:先标注数据进行训练模型,然后用模型辅助标注数据,如此建立一个数据循环。最终从1100万张图像中生成了10亿的mask,是当前最大的数据,比当前已有的数据集多了400倍的mask。

data engine

2.模型

模型结构如下

模型结构

2.1 image encoder


利用mae预训练的vit,最低限度适应高分辨率的输入,该encoder在prompt encoder之前,对每张图像只运行一次。
输入(c,h,w)的图像,对图像进行缩放,按照长边缩放成1024,短边不够就pad,得到(c,1024,1024)的图像,经过image encoder,得到对图像16倍下采样的feature,大小为(256,64,64)。

2.2 prompt encoder

  • 分成2类:稀疏的(点,box,文本),稠密的(mask)
    • point:映射到256维的向量,包含代表点位置的 positional encoding,加2个代表该点是前景/背景的可学习的embedding。
    • box:用一个embedding对表示(1)可学习的embedding代表左上角(2)可学习的embedding代表右下角
    • 文本:通过CLIP模型进行文本编码
    • mask:用输入图像1/4分辨率的mask,然后用(2,2)卷积核,stride-2输出channel为4和16,再用(1,1)卷积核将channel升到256. mask 和iamge embedding通过element-wise相乘(逐元素相乘,可以理解成mask的feature对image的feature进行加权)

2.3 mask decoder

mask decoder模块

  • 在prompt embeddings中插入一个可学习的token,用于docoder的输出。

(1)prompt toekns+output tokens进行self attn,

(2)用得到的token和image embedding进行 cross attn(token作为Q)

(3)point-wise MLP 更新token

(4)用image embedding和(3)的token进行cross atten(image embedding作为Q)

重复上述步骤2次,再将attn再通过残差进行连接,最终输出masks和iou scores。

为了解决输出模糊性问题(一个提示可能生成多个mask,比如衣服上的一个点,既可以表示衣服,也表示穿衣服的人),预测输出多个masks(发现**整体,部分,子部分**已经足够描述mask),在训练过程中,只回传最小的loss,为了对mask进行排序,增加一个小的head预测mask和目标的iou。

当输入多个提示时,生成的mask会比较接近,为了减少loss退化和确保获取明确的mask,此时只预测一个mask(作为第4个预测mask,只有多个提示时才预测,当单个提示时不用)

2.4 模型训练

训练时模拟交互分割的过程,从目标mask中随机选取前景点或者box,点是从gt mask选取,box增加长边10%的噪声,最大20像素。

在第一次prompt预测mask之后,后续是从预测mask和gt mask有差异的区域采样点,如果新生成的点是FN,则作为前景,如果是FP,则作为背景。同时,将预测的mask(unthresholded mask logits代替二值化的mask,不过滤阈值,默认为0),作为prompt作为迭代。

训练过程中,发现用8个采样点比较合适(对比16个,没有明显增益),为了鼓励模型从mask中获益,其中2个迭代不用新采样的点,总共11个迭代,1一个是初始化的prompt输入,然后是8个上述迭代,再加2个不重新采样点的迭代(这样可以refine mask)。由于mask decoder比较轻,所以可以进行更多次的迭代。

* loss

mask 用focal loss和dice loss进行线性组合,系数(20:1),iou 用mse loss。

* 训练时间

256 A100 GPUs,3-5天(jd看了下,A100价格6万左右,256个,1000多万,money is all you need

3.data engine(数据引擎)

  • 辅助人工标注

通过SAM基于浏览器的交互式分割工具,通过“brush”和"eraser"工具,进行标注。模型可以实时输出mask,建议标注者优先标记他们命名的对象,按图层顺序标记,如果一个mask标记超过30s,先处理下一张。

SAM先用公开数据集训练,然后再用新增的标注mask训练。随着数据越多,image-encoder的能力越强,retrained了6次。随着模型改进,每个mask平均标注时间从34s到14s,平均每张图像mask从22增加到44个。在这个过程中,从12万图像中,收集了430万个mask。

  • 半自动

增加mask的多样性,首先检测出可信的mask,然后用预测mask填充图像,让标注者标注未标记的mask。为了检测可信的mask,先用第一步的mask训练了一个类别一样的box检测器。半自动过程中,从18万张图像中生成了590万个mask。用新收集的数据,重新训练模型,平均标注时间又回到了34s,因为新的mask都是比较有难度的。每张图像上mask从44增加到72。

  • 全自动

利用前2步,得到的大量的和多样性的mask,结合模型可以根据不明确的输入也能输出有效的mask(参考mask encoder),对图像生成(32,32)个格网点,每个点预测一系列mask,如果一个点落在部分、子部分上,模型返回部分、子部分和整体的object。同时,通过预测的iou筛选 confident(可信的mask),选取一个stable的mask(稳定的mask,在相似的mask中,概率阈值在 0.5-δ和 0.5-δ之间);最后,通过nms过滤confidentstable中重复的mask。

为了提高mask比较小的,还通过放大图像进行crop,处理多个mask覆盖的情况。

在1100万数据集上,生成了11亿高质量的mask。

  • 数据情况

* 图片:从合作商获取1100万张图像,按短边重采样到1500像素。

* mask:99.1%都是自动生成的,通过对比分析,自动生成的mask质量也是非常高的。为了评估质量,随机选500张图像(约5万个mask),让专业的标注人员进行标注,通过对比发现94%的mask有90%以上的iou。

* 数据分布更广,从全世界获取数据,mask更多,数据偏向性较小。

4.实验

论文做了几个实验,此处仅列instance segmentation和Text-to-Mask结果,其他可以参考论文

  • instance segmentation

对比用ViTDet-H在COCO和LVIS上训练,SAM自动生成,发现SAM接近ViTDet训练的结果。在边界清晰处,SAM表现更好。为了调查这个情况,让标注人员对mask质量从0-10进行评分,可以发现在mask质量比较好的情况下,SAM产生的mask质量更好。

对比vit在coco和lvis监督效果

  • Text-to-Mask

通过文本输入就可以自动分割,因为该功能是探索性的,demo中未开放该功能,也许这是分割任务的未来形态。

文本提示分割

  • 思考与测试
    • 虽然SAM很强,但也有一些不足:会生成一些不连续的mask,错过小目标,边界可能不够清晰;text-to-mask尚不稳定;处理提示可以实时,但对图像的encoder不是实时的,用的vit。
      测试了下,在4090显卡上,推理一张(1800,1200)的图像,大概用了0.233s,显存用了6.5G,其中大部分时间都耗在image encoder上(0.231s)。
    • 可以将SAM应用到更多方向,用SAM分割对象,并泛化到未见过的对象,有研究用SAM生成检测的box,集成到标注工具,自动标注数据。可以参考项目,了解SAM的扩展应用。

Grounded-Segment-Anything​github.com/IDEA-Research/Grounded-Segment-Anything

https://github.com/anuragxel/salt​github.com/anuragxel/salt

GitHub - anuragxel/salt: Segment Anything Labelling Tool

salt半自动标注​github.com/anuragxel/salt

SegDrawer:web端标注​github.com/lujiazho/SegDrawer

    • SAM目前生成的mask都是无标签的,对于需要类别的mask,可以考虑增加head,监督训练类别。可以参考SSA这个项目,利用COCO和ADE20K生成一个类别标签,再结合Open-vocabulary生成多个类别标签,最终过滤得到多个可能的标签。

Semantic-Segment-Anything​github.com/fudan-zvg/Semantic-Segment-Anything

    • 推理部署,测试将模型转成onnx时,没有包含image encoder部分,先需要得到图像的image_embedding,再用onnx去推理。已经有分支支持转onnx和量化。

image encoder转onnx代码​github.com/visheratin/segment-anything/blob/main/scr

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值