SAM2 Code Implementation Explained: The Image Encoder (FpnNeck)
The Top-Level YAML Config File, OmegaConf, and Hydra
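The official SAM2 implementation uses YAML files to configure the overall model architecture and its parameters. The key code (`build_sam2`) is as follows: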
# Imports this function relies on (as in sam2/build_sam.py)
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
):
    if apply_postprocessing:
        # Copy first so the caller's list (and the shared default) is never mutated
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    # (compose() needs Hydra to be initialized; SAM2 does this when the sam2 package is imported)
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    # _load_checkpoint is a helper defined elsewhere in the same module
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model
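The code from `if apply_postprocessing:` through `OmegaConf.resolve(cfg)` prepares the model configuration. `OmegaConf.resolve(cfg)` resolves any `${...}` interpolations in the config in place. The `compose` and `instantiate` functions both come from the Hydra library. Hydra is an open-source Python framework, also developed at Meta, that simplifies the development of research code and other complex applications. Its main feature is the ability to build a hierarchical configuration dynamically by composition, and then override it through config files and the command line. Hydra reads and writes YAML files through the OmegaConf library.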
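To illustrate "building a hierarchical configuration by composition", here is a minimal pure-Python sketch of the underlying idea. A base config and a second config layered on top of it are deep-merged, and the later values win. `deep_merge` is a hypothetical helper written only for illustration, and the keys are loosely modeled on the SAM2 neck settings. Hydra and OmegaConf implement composition with much richer rules (defaults lists, type checking, interpolation), so treat this as an intuition aid rather than Hydra's actual algorithm.

```python
from copy import deepcopy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: merge `override` into a copy of `base`, recursively."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are mappings: merge them key by key.
            merged[key] = deep_merge(merged[key], value)
        else:
            # Otherwise the later config wins (new keys are simply added).
            merged[key] = deepcopy(value)
    return merged


# Base config, loosely modeled on the SAM2 image_encoder/neck settings.
base_cfg = {
    "model": {
        "image_encoder": {
            "scalp": 1,
            "neck": {"d_model": 256, "fpn_interp_model": "nearest"},
        }
    }
}
# A second config layered on top: it changes one leaf and adds a new key.
experiment_cfg = {
    "model": {
        "image_encoder": {"neck": {"fpn_interp_model": "bilinear"}},
        "debug": True,  # hypothetical extra key, added because it is new
    }
}

cfg = deep_merge(base_cfg, experiment_cfg)
print(cfg)
```

Untouched leaves such as `scalp` and `d_model` survive the merge, and `base_cfg` itself is left unchanged.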
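Back in `build_sam2`, `compose` reads the YAML file named by the `config_name` argument and produces a config object that can be accessed much like a Python dict. It then applies the entries of the `overrides` argument on top of the values loaded from YAML. In these override strings, the `++` prefix is Hydra's "force add" syntax: the key is overridden if it already exists and added if it does not.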
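The override strings in `build_sam2` address nested keys with dotted paths. The pure-Python sketch below shows roughly what applying such `++key.path=value` overrides to a nested dict amounts to. `apply_overrides` and `parse_value` are hypothetical helpers. The value parsing here only handles booleans, ints, floats and plain strings, whereas Hydra's real override grammar is considerably richer (lists, quoting, `+` and `~` prefixes, sweeps).

```python
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    """Hypothetical, simplified value parser: bool, int, float, else str."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply '++a.b.c=value' overrides in place with force-add semantics.

    Missing intermediate nodes are created and existing leaves are replaced.
    """
    for item in overrides:
        key_path, _, raw_value = item.partition("=")
        keys = key_path.lstrip("+").split(".")
        node = cfg
        for key in keys[:-1]:
            node = node.setdefault(key, {})  # create the node if it is missing
        node[keys[-1]] = parse_value(raw_value)
    return cfg


# The same three post-processing overrides that build_sam2 appends.
overrides = [
    "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
    "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
    "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
]
cfg = {"model": {"_target_": "sam2.modeling.sam2_base.SAM2Base"}}
apply_overrides(cfg, overrides)
print(cfg["model"]["sam_mask_decoder_extra_args"])
```

Because `sam_mask_decoder_extra_args` does not exist in this toy config, the force-add behaviour creates it. Existing keys such as `_target_` are left alone.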
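`instantiate` then builds the actual network from the configuration. For each node that has a `_target_` key, it imports the class named there and calls it with the node's remaining keys as constructor arguments. With `_recursive_=True`, nested nodes are instantiated first and passed to their parents as objects. This is easier to follow with an example: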
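Example YAML file:
optimizer:
  _target_: my_app.Optimizer
  algo: SGD
  lr: 0.01
Example class (in the module my_app):
class Optimizer:
    algo: str
    lr: float

    def __init__(self, algo: str, lr: float) -> None:
        self.algo = algo
        self.lr = lr

    def __repr__(self) -> str:
        # Gives print() the readable output shown in the comments below
        return f"Optimizer(algo={self.algo},lr={self.lr})"
Example instantiation:
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")  # assumed file name for the YAML above
opt = instantiate(cfg.optimizer)
print(opt)
# Optimizer(algo=SGD,lr=0.01)
# override parameters on the call-site
opt = instantiate(cfg.optimizer, lr=0.2)
print(opt)
# Optimizer(algo=SGD,lr=0.2)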
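The core idea of `instantiate` can be sketched in a few lines of pure Python. It looks up the class or function named by `_target_`, optionally builds nested nodes first, and then calls the target with the remaining keys plus any call-site keyword overrides. `mini_instantiate` and `locate` are hypothetical names, and this is not Hydra's code: it ignores most of what the real `instantiate` supports, such as `_args_`, `_partial_`, `_convert_` and config objects. To stay runnable without SAM2 installed, the nested config mirrors the shape of SAM2's `image_encoder` node but uses the standard library's `types.SimpleNamespace` as a stand-in `_target_`.

```python
import importlib
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted path such as 'types.SimpleNamespace' to an object."""
    module_name, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


def mini_instantiate(node: Any, recursive: bool = True, **overrides: Any) -> Any:
    """Simplified stand-in for hydra.utils.instantiate (illustration only)."""
    if isinstance(node, list):
        return [mini_instantiate(item) for item in node]
    if not isinstance(node, dict):
        return node  # plain values (ints, strings, ...) pass through unchanged
    kwargs = {k: v for k, v in node.items() if k != "_target_"}
    kwargs.update(overrides)  # call-site overrides win, like instantiate(cfg, lr=0.2)
    if recursive:
        # Build nested nodes first (depth-first), so the parent receives objects.
        kwargs = {k: mini_instantiate(v) for k, v in kwargs.items()}
    if "_target_" not in node:
        return kwargs  # a node without _target_ stays a plain dict
    return locate(node["_target_"])(**kwargs)


# Same nesting as SAM2's image_encoder node, with stand-in targets.
image_encoder_cfg = {
    "_target_": "types.SimpleNamespace",
    "scalp": 1,
    "trunk": {
        "_target_": "types.SimpleNamespace",
        "embed_dim": 96,
        "stages": [1, 2, 7, 2],
    },
    "neck": {
        "_target_": "types.SimpleNamespace",
        "d_model": 256,
        "backbone_channel_list": [768, 384, 192, 96],
    },
}

encoder = mini_instantiate(image_encoder_cfg)
print(encoder.trunk.embed_dim, encoder.neck.d_model)  # 96 256

# Non-recursive mode leaves children as raw config dicts; scalp is overridden.
shallow = mini_instantiate(image_encoder_cfg, recursive=False, scalp=0)
print(type(shallow.trunk).__name__, shallow.scalp)  # dict 0
```

The recursive case is what `_recursive_=True` in `build_sam2` relies on. `trunk` and `neck` become objects before the parent constructor receives them.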
Next, let's look at the concrete configuration of SAM2, using the tiny version as an example:
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 96
      num_heads: 1
      stages: [1, 2, 7, 2]
      global_att_blocks: [5, 7, 9]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [768, 384, 192, 96]
      fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest
  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4
  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: