本文基于SAM之MaskDecoder总结(个人研究)_sam maskdecoder-CSDN博客进行总结。
经过两个叠加的双向注意力模块和最后一个token-image交叉注意力后,输出Query和Keys(具体过程可见上一篇文章中TwoWayTransformer代码)。其中,Query中的第一个向量为下图中的iou_out,后面的向量为mask_out。Keys为下图中的src。
class MaskDecoder(nn.Module):
# 定义一个用于预测掩码的解码器类,基于 Transformer 架构
def __init__(
self,
*,
transformer_dim: int, # Transformer 的嵌入维度
transformer: nn.Module, # 用于掩码预测的 Transformer 模块
num_multimask_outputs: int = 3, # 多掩码输出的数量,默认值为 3
activation: Type[nn.Module] = nn.GELU, # 上采样掩码时使用的激活函数类型,默认为 GELU
iou_head_depth: int = 3, # 用于预测掩码质量的 MLP 的深度,默认为 3
iou_head_hidden_dim: int = 256, # 用于预测掩码质量的 MLP 的隐藏层维度,默认为 256
) -> None:
"""
使用 Transformer 架构根据图像和提示嵌入预测掩码。
参数:
transformer_dim (int)