SMA2:代码实现详解——Image Encoder篇(FpnNeck章)

SMA2:代码实现详解——Image Encoder篇(FpnNeck)

在这里插入图片描述

总配置YAML文件、OmegaConf和hydra

SAM2的官方实现是使用yaml文件来配置整体的模型结构与参数的。关键代码如下:

def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
):

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model

从代码的第10行到第20行都是在配置模型参数。第19行的compose函数与第21行的instantiate函数都是hydra库的库函数。Hydra是一个开源Python框架,也是由Meta团队开发的,它可简化研究和其他复杂应用程序的开发。其主要功能是能够通过组合动态创建分层配置,并通过配置文件和命令行覆盖它。Hydra对yaml文件的读写操作是基于OmegaConf库的。

回到我们的代码,第19行的compose函数用来读取config_name参数指定的yaml文件,生成可类似于Dict访问的Python对象,并根据overrides参数的内容,覆盖从yaml得到的部分参数内容。

第21行的instantiate函数根据yaml文件中的配置信息实际构建网络模型。这个地方只用文字可能不太好理解,我们举个例子:
例子yaml文件:

optimizer:
  _target_: my_app.Optimizer
  algo: SGD
  lr: 0.01

例子class文件:

class Optimizer:
    algo: str
    lr: float

    def __init__(self, algo: str, lr: float) -> None:
        self.algo = algo
        self.lr = lr

例子实例化函数:

opt = instantiate(cfg.optimizer)
print(opt)
# Optimizer(algo=SGD,lr=0.01)

# override parameters on the call-site
opt = instantiate(cfg.optimizer, lr=0.2)
print(opt)
# Optimizer(algo=SGD,lr=0.2)

那么我们接下来见一下SMA2的具体构造(以tiny版本为例):

model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 96
      num_heads: 1
      stages: [1, 2, 7, 2]
      global_att_blocks: [5, 7, 9]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [768, 384, 192, 96]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
      _target_: sam2.modeling.memory_encoder.MemoryEncoder
      out_dim: 64
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 64
        normalize: true
        scale: null
        temperature: 
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值