1.背景
在网络数据集上预训练的大语言模型具有强大的zero-shot(零样本)和few-shot(少样本)的泛化能力,这些"基础模型"可以推广到超出训练过程中的任务和数据分布,这种能力通过“prompt engineering”实现,具体就是输入提示语得到有效的文本输出,使用网络上的大量文本资料库进行缩放和训练后,发现这种零样本和少样本的训练的模型比微调模型效果还要好,数据集越大,效果越明显。
视觉任务上也对这种基础模型进行了探索,比如CLIP和ALIGN利用对比学习,将文本和图像编码进行了对齐,通过提示语生成image encoder,就可以扩展到下游任务,比如生成图像。
论文的目的是建立一个图像分割的基础模型,开发一个具有提示能力的模型。
要解决的3个问题:
(1)什么任务可以实现零样本?
通过提示输入,生成有效的mask,当提示是不确定的,能生成多个objects(比如衣服上的一个点,既可以表示衣服,也表示穿衣服的人),如下图所示:提示可以是点,矩形框,文字,mask,或者是图像。
prompt提示
(2)模型结构应该是什么样?
模型要支持灵活的提示,且要实时生成mask,对输出也是模糊的(比如表示衣服还是穿衣服的人),设计结构如下:一个prompt encoder,对提示进行编码,image encoder对图像编码,生成embedding, 最后融合2个encoder,再接一个轻量的mask decoder,输出最后的mask。
prompt encoder
(3)数据怎么支持这些任务?
需要一个大量且多样化的mask数据。自然语言数据是通过在线获取,但是mask数据是不足的,需要一个替代策略。
方案就是建立一个“数据引擎”,分成3步:人工辅助(帮助标注,类似交互式分割),半自动(通过提供提示,自动生成对象mask),全自动(通过规则格网作为提示,进行自动生成)。
如下图所示:先标注数据进行训练模型,然后用模型辅助标注数据,如此建立一个数据循环。最终从1100万张图像中生成了10亿的mask,是当前最大的数据,比当前已有的数据集多了400倍的mask。
data engine
2.模型
模型结构如下
模型结构
2.1 image encoder
利用mae预训练的vit,最低限度适应高分辨率的输入,该encoder在prompt encoder之前,对每张图像只运行一次。
输入(c,h,w)的图像,对图像进行缩放,按照长边缩放成1024,短边不够就pad,得到(c,1024,1024)的图像,经过image encoder,得到对图像16倍下采样的feature,大小为(256,64,64)。
2.2 prompt encoder
- 分成2类:稀疏的(点,box,文本),稠密的(mask)
- point:映射到256维的向量,包含代表点位置的 positional encoding,加2个代表该点是前景/背景的可学习的embedding。
- box:用一个embedding对表示(1)可学习的embedding代表左上角(2)可学习的embedding代表右下角
- 文本:通过CLIP模型进行文本编码
- mask:用输入图像1/4分辨率的mask,然后用(2,2)卷积核,stride-2输出channel为4和16,再用(1,1)卷积核将channel升到256. mask 和iamge embedding通过element-wise相乘(逐元素相乘,可以理解成mask的feature对image的feature进行加权)
2.3 mask decoder
mask decoder模块
- 在prompt embeddings中插入一个可学习的token,用于docoder的输出。
(1)prompt toekns+output tokens进行self attn,
(2)用得到的token和image embedding进行 cross attn(token作为Q)
(3)point-wise MLP 更新token
(4)用image embedding和(3)的token进行cross atten(image embedding作为Q)
重复上述步骤2次,再将attn再通过残差进行连接,最终输出masks和iou scores。
为了解决输出模糊性问题(一个提示可能生成多个mask,比如衣服上的一个点,既可以表示衣服,也表示穿衣服的人),预测输出多个masks(发现**整体,部分,子部分**已经足够描述mask),在训练过程中,只回传最小的loss,为了对mask进行排序,增加一个小的head预测mask和目标的iou。
当输入多个提示时,生成的mask会比较接近,为了减少loss退化和确保获取明确的mask,此时只预测一个mask(作为第4个预测mask,只有多个提示时才预测,当单个提示时不用)
2.4 模型训练
训练时模拟交互分割的过程,从目标mask中随机选取前景点或者box,点是从gt mask选取,box增加长边10%的噪声,最大20像素。
在第一次prompt预测mask之后,后续是从预测mask和gt mask有差异的区域采样点,如果新生成的点是FN,则作为前景,如果是FP,则作为背景。同时,将预测的mask(unthresholded mask logits代替二值化的mask,不过滤阈值,默认为0),作为prompt作为迭代。
训练过程中,发现用8个采样点比较合适(对比16个,没有明显增益),为了鼓励模型从mask中获益,其中2个迭代不用新采样的点,总共11个迭代,1一个是初始化的prompt输入,然后是8个上述迭代,再加2个不重新采样点的迭代(这样可以refine mask)。由于mask decoder比较轻,所以可以进行更多次的迭代。
* loss
mask 用focal loss和dice loss进行线性组合,系数(20:1),iou 用mse loss。
* 训练时间
256 A100 GPUs,3-5天(jd看了下,A100价格6万左右,256个,1000多万,money is all you need)
3.data engine(数据引擎)
- 辅助人工标注
通过SAM基于浏览器的交互式分割工具,通过“brush”和"eraser"工具,进行标注。模型可以实时输出mask,建议标注者优先标记他们命名的对象,按图层顺序标记,如果一个mask标记超过30s,先处理下一张。
SAM先用公开数据集训练,然后再用新增的标注mask训练。随着数据越多,image-encoder的能力越强,retrained了6次。随着模型改进,每个mask平均标注时间从34s到14s,平均每张图像mask从22增加到44个。在这个过程中,从12万图像中,收集了430万个mask。
- 半自动
增加mask的多样性,首先检测出可信的mask,然后用预测mask填充图像,让标注者标注未标记的mask。为了检测可信的mask,先用第一步的mask训练了一个类别一样的box检测器。半自动过程中,从18万张图像中生成了590万个mask。用新收集的数据,重新训练模型,平均标注时间又回到了34s,因为新的mask都是比较有难度的。每张图像上mask从44增加到72。
- 全自动
利用前2步,得到的大量的和多样性的mask,结合模型可以根据不明确的输入也能输出有效的mask(参考mask encoder),对图像生成(32,32)个格网点,每个点预测一系列mask,如果一个点落在部分、子部分上,模型返回部分、子部分和整体的object。同时,通过预测的iou筛选 confident(可信的mask),选取一个stable的mask(稳定的mask,在相似的mask中,概率阈值在 0.5-δ和 0.5-δ之间);最后,通过nms过滤confident和stable中重复的mask。
为了提高mask比较小的,还通过放大图像进行crop,处理多个mask覆盖的情况。
在1100万数据集上,生成了11亿高质量的mask。
- 数据情况
* 图片:从合作商获取1100万张图像,按短边重采样到1500像素。
* mask:99.1%都是自动生成的,通过对比分析,自动生成的mask质量也是非常高的。为了评估质量,随机选500张图像(约5万个mask),让专业的标注人员进行标注,通过对比发现94%的mask有90%以上的iou。
* 数据分布更广,从全世界获取数据,mask更多,数据偏向性较小。
4.实验
论文做了几个实验,此处仅列instance segmentation和Text-to-Mask结果,其他可以参考论文
- instance segmentation
对比用ViTDet-H在COCO和LVIS上训练,SAM自动生成,发现SAM接近ViTDet训练的结果。在边界清晰处,SAM表现更好。为了调查这个情况,让标注人员对mask质量从0-10进行评分,可以发现在mask质量比较好的情况下,SAM产生的mask质量更好。
对比vit在coco和lvis监督效果
- Text-to-Mask
通过文本输入就可以自动分割,因为该功能是探索性的,demo中未开放该功能,也许这是分割任务的未来形态。
文本提示分割
- 思考与测试
- 虽然SAM很强,但也有一些不足:会生成一些不连续的mask,错过小目标,边界可能不够清晰;text-to-mask尚不稳定;处理提示可以实时,但对图像的encoder不是实时的,用的vit。
测试了下,在4090显卡上,推理一张(1800,1200)的图像,大概用了0.233s,显存用了6.5G,其中大部分时间都耗在image encoder上(0.231s)。 - 可以将SAM应用到更多方向,用SAM分割对象,并泛化到未见过的对象,有研究用SAM生成检测的box,集成到标注工具,自动标注数据。可以参考项目,了解SAM的扩展应用。
- 虽然SAM很强,但也有一些不足:会生成一些不连续的mask,错过小目标,边界可能不够清晰;text-to-mask尚不稳定;处理提示可以实时,但对图像的encoder不是实时的,用的vit。
Grounded-Segment-Anythinggithub.com/IDEA-Research/Grounded-Segment-Anything
https://github.com/anuragxel/saltgithub.com/anuragxel/salt
GitHub - anuragxel/salt: Segment Anything Labelling Tool
salt半自动标注github.com/anuragxel/salt
SegDrawer:web端标注github.com/lujiazho/SegDrawer
-
- SAM目前生成的mask都是无标签的,对于需要类别的mask,可以考虑增加head,监督训练类别。可以参考SSA这个项目,利用COCO和ADE20K生成一个类别标签,再结合Open-vocabulary生成多个类别标签,最终过滤得到多个可能的标签。
Semantic-Segment-Anythinggithub.com/fudan-zvg/Semantic-Segment-Anything
-
- 推理部署,测试将模型转成onnx时,没有包含image encoder部分,先需要得到图像的image_embedding,再用onnx去推理。已经有分支支持转onnx和量化。
image encoder转onnx代码github.com/visheratin/segment-anything/blob/main/scr