SAM相关下游任务研习

#! https://zhuanlan.zhihu.com/p/662956503
这是一篇SAM领域下游任务的学习笔记。
原文:

A Comprehensive Survey on Segment Anything
Model for Vision and Beyond

SAM for image processing

Software Scenes

Image editing

可以应用于去除物体,填充物体,替代物体。和AIGC一起应用。
9042e7705665982548c60e13307e91cc
例子:Edit everything

模型采用SAM+CLIP+Stable Diffusion,先用SAM找到所有分割,再根据source prompt给分割排序通过CLIP找到最高分的目标分割,再写一个target prompt通过Stable Diffusion来生成新物品代替目标。
没有提出什么新代码

Style Transfer

将原图片的风格转移成给定图片的风格。SAM的promptable区域选择让用户能够选定区域进行风格迁移。
2fb9641d04f48bc9b0c57fb2e006c4bc

例子:Any-to-Any Style Transfer: Making Picasso and Da Vinci Collaborate

通过把SAM和其他的风格转换模型结合起来实现风格迁移。没有训练新模型,网络结构如下:

 """ Building Models """
    transformer_path = 'ckpt/latest_net_transformer.pth'
    decoder_path = 'ckpt/latest_net_decoder.pth'
    ada_attn_3_path = 'ckpt/latest_net_adaattn_3.pth'
    vgg_path = 'ckpt/vgg_normalised.pth'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_encoder = nn.Sequential(
        nn.Conv2d(3, 3, (1, 1)),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(3, 64, (3, 3)),
        nn.ReLU(),  # relu1-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(64, 64, (3, 3)),
        nn.ReLU(),  # relu1-2
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(64, 128, (3, 3)),
        nn.ReLU(),  # relu2-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(128, 128, (3, 3)),
        nn.ReLU(),  # relu2-2
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(128, 256, (3, 3)),
        nn.ReLU(),  # relu3-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 256, (3, 3)),
        nn.ReLU(),  # relu3-4
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(256, 512, (3, 3)),
        nn.ReLU(),  # relu4-1, this is the last layer used
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu4-4
        nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-1
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-2
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU(),  # relu5-3
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(512, 512, (3, 3)),
        nn.ReLU()  # relu5-4
    )
    image_encoder.load_state_dict(torch.load(vgg_path))
    enc_layers = list(image_encoder.children())
    enc_1 = nn.Sequential(*enc_layers[:4]).to(device)
    enc_2 = nn.Sequential(*enc_layers[4:11]).to(device)
    enc_3 = nn.Sequential(*enc_layers[11:18]).to(device)
    enc_4 = nn.Sequential(*enc_layers[18:31]).to(device)
    enc_5 = nn.Sequential(*enc_layers[31:44]).to(device)
    image_encoder_layers = [enc_1, enc_2, enc_3, enc_4, enc_5]
    for layer in image_encoder_layers:
        layer.eval()
        for p in layer.parameters():
            p.requires_grad = False
    transformer = Transformer(in_planes=512, key_planes=512 + 256 + 128 + 64).to(device)
    decoder = Decoder().to(device)
    ada_attn_3 = AdaAttN(in_planes=256, key_planes=256 + 128 + 64, max_sample=64 * 64).to(device)
    transformer.load_state_dict(torch.load(transformer_path))
    decoder.load_state_dict(torch.load(decoder_path))
    ada_attn_3.load_state_dict(torch.load(ada_attn_3_path))
    transformer.eval()
    decoder.eval()
    ada_attn_3.eval()
    for p in transformer.parameters():
        p.requires_grad = False
    for p in decoder.parameters():
        p.requires_grad = False
    for p in ada_attn_3.parameters():
        p.requires_grad = False

    def encode_with_intermediate(img):
        results = [img]
        for i in range(len(image_encoder_layers)):
            func = image_encoder_layers[i]
            results.append(func(results[-1]))
        return results[1:]

    def style_transfer():
        with torch.no_grad():
            style = img_to_tensor(cv2.cvtColor(padding(style_im, 32), cv2.COLOR_BGR2RGB)).to(device)
            content = img_to_tensor(cv2.cvtColor(padding(content_im, 32), cv2.COLOR_BGR2RGB)).to(device)
            c_masks = [torch.from_numpy(padding(m, 32)).unsqueeze(0).permute(0, 3, 1, 2).float().to(device)
                       for m in all_mask_c]
            s_masks = [torch.from_numpy(padding(m, 32)).unsqueeze(0).permute(0, 3, 1, 2).float().to(device)
                       for m in all_mask_s]
            c_feats = encode_with_intermediate(content)
            s_feats = encode_with_intermediate(style)
            c_adain_feat_3 = ada_attn_3(c_feats[2], s_feats[2], get_key(c_feats, 2), get_key(s_feats, 2), None,
                                        c_masks, s_masks)
            cs = transformer(c_feats[3], s_feats[3], c_feats[4], s_feats[4], get_key(c_feats, 3), get_key(s_feats, 3),
                             get_key(c_feats, 4), get_key(s_feats, 4), None, c_masks, s_masks)
            cs = decoder(cs, c_adain_feat_3)
            cs = tensor_to_img(cs[:, :, :h, :w])
            cs = cv2.cvtColor(cs, cv2.COLOR_RGB2BGR)
            return cs

计算出来style掩码,原图掩码,编码后原图送进ada_attn_3处理,处理后和style一起送进transforer,再返回解码器。

Real-World Scenes

Detection

实验表明SAM对自然图像处理较好,对低出现具体应用场景应用比较差,需要很强的先验知识。
8b74283abdcc2382b3ee9f93dbfda42a

Counting

实验表明,SAM在few-shot侦测small,crowded物品实验中的表现与baseline相比还有一定差距。

Moving Object

例子:DSEC-MOS: Segment Any Moving Object with
Moving Ego Vehicle

Moving Object分割是CV中很困难的一个问题,现有的数据集主要是RGB或者Lidar videos,缺少动态场景事件信息。DSEC-MOS提出了新的用于自动驾驶移动物体分割的数据集,结合了基于事件的视觉数据,通过使用SAM和DSEC-MOD提供的移动物体边界框,生成准确的分割掩码注释。

复杂场景

Low-Contrast Scene

例子:SAM Struggles in Concealed Scenes –
Empirical Study on “Segment Anything”

本文对三个隐蔽场景(伪装动物、工业缺陷和医学病变)进行评估,发现SAM在隐蔽背景中表现不佳

Leaf Only SAM: A Segment Anything Pipeline for Zero-Shot Automated Leaf Segmentation

本文利用SAM和一系列后处理步骤来进行马铃薯叶片的分割,不需要任何训练数据,适用于植物表型分析等领域。与在小性感马铃薯叶片数据集上微调的Mask R-CNN模型相比性能略低,但是SAM是zero-shot分割,具有零样本分类器的潜力。
代码:

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
import pandas as pd
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device);
folder = 'string of path to folder of images'
folder_out = 'string of path to where you want to save output'

imnames = [x for x in os.listdir(folder) if '.JPG' in x]  # get list of image files change .JPG is using files of different type
def checkcolour(masks, hsv):
    colours = np.zeros((0,3))

    for i in range(len(masks)):
        color = hsv[masks[i]['segmentation']].mean(axis=(0))
        colours = np.append(colours,color[None,:], axis=0)
        
    idx_green = (colours[:,0]<75) & (colours[:,0]>35) & (colours[:,1]>35)
    if idx_green.sum()==0:
        # grow lights on adjust
        idx_green = (colours[:,0]<100) & (colours[:,0]>35) & (colours[:,1]>35)
    
    return(idx_green)
def checkfullplant(masks):
    mask_all = np.zeros(masks[0]['segmentation'].shape[:2])

    for mask in masks:
        mask_all +=mask['segmentation']*1
        
    iou_withall = []
    for mask in masks:
        iou_withall.append(iou(mask['segmentation'], mask_all>0))
        
    idx_notall = np.array(iou_withall)<0.9
    return idx_notall
def getbiggestcontour(contours):
    nopoints = [len(cnt) for cnt in contours]
    return(np.argmax(nopoints))

def checkshape(masks):
    cratio = []

    for i in range(len(masks)):
        test_mask = masks[i]['segmentation']
        
        if not test_mask.max():
            cratio.append(0)
        else:

            contours,hierarchy = cv2.findContours((test_mask*255).astype('uint8'), 1, 2)

            # multiple objects possibly detected. Find contour with most points on it and just use that as object
            cnt = contours[getbiggestcontour(contours)]
            M = cv2.moments(cnt)

            area = cv2.contourArea(cnt)
            perimeter = cv2.arcLength(cnt,True)

            (x,y),radius = cv2.minEnclosingCircle(cnt)

            carea = np.pi*radius**2

            cratio.append(area/carea)
    idx_shape = np.array(cratio)>0.1
    return(idx_shape)
def iou(gtmask, test_mask):
    intersection = np.logical_and(gtmask, test_mask)
    union = np.logical_or(gtmask, test_mask)
    iou_score = np.sum(intersection) / np.sum(union)
    return (iou_score)
def issubset(mask1, mask2):
    # is mask2 subpart of mask1
    intersection = np.logical_and(mask1, mask2)
    return(np.sum(intersection)/mask2.sum()>0.9)

def istoobig(masks):
    idx_toobig = []
    
    mask_all = np.zeros(masks[0]['segmentation'].shape[:2])

    for mask in masks:
        mask_all +=mask['segmentation']*1 

    for idx in range(len(masks)):
        if idx in idx_toobig:
            continue
        for idx2 in range(len(masks)):
            if idx==idx2:
                continue
            if idx2 in idx_toobig:
                continue
            if issubset(masks[idx2]['segmentation'], masks[idx]['segmentation']):
                # check if actually got both big and small copy delete if do
                if mask_all[masks[idx2]['segmentation']].mean() > 1.5:
                
                    idx_toobig.append(idx2)
    
    idx_toobig.sort(reverse=True)        
    return(idx_toobig)

def remove_toobig(masks, idx_toobig):
    masks_ntb = masks.copy()

    idx_del = []
    for idxbig in idx_toobig[1:]:
        maskbig = masks_ntb[idxbig]['segmentation'].copy()
        submasks = np.zeros(maskbig.shape)

        for idx in range(len(masks_ntb)):
            if idx==idxbig:
                continue
            if issubset(masks_ntb[idxbig]['segmentation'], masks_ntb[idx]['segmentation']):
                submasks +=masks_ntb[idx]['segmentation']

        if np.logical_and(maskbig, submasks>0).sum()/maskbig.sum()>0.9:
            # can safely remove maskbig
            idx_del.append(idxbig)
            del(masks_ntb[idxbig])
            
    return(masks_ntb)
for imname in imnames:
    print(imname)
    image = cv2.imread(folder + imname)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image,None,fx=0.5,fy=0.5)   # downsize image to fit on gpu easier may not be needed
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
           
    # use crop_n_layer=1 to improve results on smallest leaves 
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=200,  
    )

    # get masks
    masks = mask_generator.generate(image)
    
    # remove things that aren't green enough to be leaves
    idx_green = checkcolour(masks,hsv)

    masks_g = []
    for idx, use in enumerate(idx_green):
        if use:
            masks_g.append(masks[idx])

    if len(masks_g) > 2:

        # check to see if full plant detected and remove
        idx_notall = checkfullplant(masks_g)

        masks_na = []

        for idx, use in enumerate(idx_notall):
            if use:
                masks_na.append(masks_g[idx])

    else:
        masks_na = masks_g

    idx_shape = checkshape(masks_na)

    masks_s = []
    for idx, use in enumerate(idx_shape):
        if use:
            masks_s.append(masks_na[idx])

    idx_toobig = istoobig(masks_s)
    masks_ntb = remove_toobig(masks_s, idx_toobig)
    
    # save results at each step as npz file 
    np.savez(folder_out + imname.replace('.JPG','leafonly_allmasks.npz'),
              masks, masks_g, masks_na, masks_s, masks_ntb)

例子:Segment Any Anomaly without Training via Hybrid Prompt Regularization, SAA+

本文提出一种新的框架 Segment Any Anomaly +(SAA+)用于zero-shot异常分割定位,允许更准确的一场分割,无需特定领域的调优。文章发现基础模型的一个简单的组合存在严重的语言歧义。因此引入了基于领域专家知识和目标图像上下文的混合提示,以减轻语言歧义。框架说明如下:
20231023202847
论文通过将基础模型组合起来,利用多模态先验知识来定位异常区域。引入了来自领域专家知识和目标图像上下文的混合提示作为规范化。实现了SOTA。

Thermal Infrared Image

例子:Learning to “Segment Anything” in Thermal Infrared Images through Knowledge Distillation with a Large Scale Dataset SATIR
20231023204239
热红外图像场景是另一种复杂的场景,图像总是较暗且难以注释。因此,大量的未标记数据被浪费,该领域的模型无法以可靠的方式学习到高精度。为了解决这一问题,文章利用SAM生成伪标签,构建了一个大型热红外分割数据集SATIR,
进行模型预训练,其中包含超过100,000张带有像素注释标签的图像。为了最终提高模型在该领域的性能,作者提出了一个三步框架,如图9所示。他们用SAM构建上述数据集,然后用它对模型进行预训练。然后,他们对目标任务的预训练模型进行微调。在公开的热红外语义分割数据SODA上的实验验证了其在该领域的有效性,其中由SATIR训练的主干模型的平均Intersection over Union (mIoU)约为1.3%,优于其他模型。

Overhead Image

例子:SAMRS: Scaling-up Remote Sensing Segmentation Dataset with Segment Anything Model

本文将SAM应用于火星表面遥感影像的地质填土任务,由于缺乏问题特定偏差,不能直接应用于特定领域任务。因此他们通过引入特定领域的解码器来改变SAM的设计,通过仅使用5张标记图像的知识蒸馏进行微调来学习特定问题的语义,表现效果很好。

VISION AND BEYOND

Vision Related Application

Medical Imaging

根据医学图像的成像格式,SAM在医学图像分割中的应用可分为:CT,磁共振,结肠镜图像,H&E年龄染色组织学切片,多格式图像等。

  • CT

    研究表明SAM可以有效地推广到CT数据。SAMed是医学图像分割的解决方案,利用了SAM并进行LoRA微调,经过训练的SAMed模型达到了与SOTA方法相当的性能。并且由于SAMed只更新SAM参数的一小部分,因此在实际使用中部署和存储成本非常低。

  • MRI

    研究将SAM与FSL的脑提取工具(BET)进行了比较,结果表明,SAM在各项系数上都比BET表现良好,特别是图像质量受到信号和病变影响的情况下。此外,SAM具有优越的分割特性,有潜力成为一个更准确、健壮和通用的工具。

总结:医学图像方面大部分是对SAM进行微调以完成具体任务。在这个过程中,SAM展现了优异的分割性能,许多微调模型达到了SOTA水平。然而,在一些涉及血液、反射、模糊和阴影的复杂手术场景中,SAM无法识别器械。此外,SAM的性能不够健壮,无法承受各种形式的数据损坏。

Video

应用于视频目标跟踪和视频分割。

Data Annotations

SAM在数据注释方面发挥了很大的作用,许多实验采用SAM生成新的数据集,包括SAM-1b本身也是SAM参与了数据集生成。
视频文本和遥感图像分割数据正在尝试使用SAM来帮助生成图像注释。SAM也参与到了高质量伪标签的生成中,使得这个过程变得非常简单、快速和高效。(也许可以参与到生成对抗神经网络?)

Beyond Vision

3D Reconstruction

SAM是2D图像分割的SOTA方法,但不能直接用于3D场景理解。文章提出利用nerf将SAM的分割能力扩展到3D场景。

Non-Euclidean Domain

SNA范式利用SAM引入了专门的可伸缩图卷积层和元学习策略来进行图分析。

Robotics

SAM的图分割正参与到机器人的实体定位。

Video Text Spotting

SAM模型应用于边界框注释,能够为大规模视频文本数据集生成掩码注释。SAM text首先从现有注释或者场景文本检测模型中提取边界框坐标,将其用作SAM模型获取掩码标签的输入提示符。

Vision and Language

主要应用于语言指导图像分割,比如字幕生成。

Audio and Vision

利用视听进行定位。

Multimodal Visualization and Open-Vocabulary Interactive Segmentation

交互式分割,SAM本身也就是基于交互式分割方法进行训练的。

更多方向

Weakly-Supervised Semantic Segmentation

SAM具有很强的适用性,可以不需要模型微调就生成训练标签和分割蒙版。

Adversarial Robustness

SAM具有黑盒鲁棒性,但是白盒脆弱性。需要深入研究相关安全问题。

One-shot

在单次学习中使用SAM可以允许仅使用单词数据就可以有效地创建个性化分割模型。

Explainable AI

(不太理解)

Conlusion

本文首次全面回顾了计算机视觉及其他领域SAM基础模型的最新进展。首先,总结了基础模型的发展历史,包括大型语言模型模型、大型可视化模型和大型多模态模型,以及SAM的基本术语。特别关注SAM在各种任务和数据类型中的应用,我们总结和比较了SAM的并发工作及其后续工作。然后,讨论了SAM在广泛的图像处理应用中的巨大潜力,包括软件场景、现实场景和复杂场景。并对其进行了分析和总结跨各种应用程序的SAM的优点和局限性。

本人对SATIR比较感兴趣,这个异常分割一是有较好的性能,而是方向比较重要。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值