【代码学习】Segment Anything 2-代码学习

SAM2-代码学习(部分)

一、项目框架

├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── INSTALL.md
├── LICENSE
├── LICENSE_cctorch
├── MANIFEST.in
├── README.md
├── RELEASE_NOTES.md
├── assets
├── backend.Dockerfile
├── checkpoints
├── demo
├── docker-compose.yaml
├── notebooks
├── pyproject.toml
├── sam2
├── sav_dataset
├── setup.py
├── tools
└── training

1. 根目录

  • 包含项目的主要配置文件和文档,如 README.mdLICENSEINSTALL.md 等。
  • backend.Dockerfiledocker-compose.yaml:用于容器化和编排。
    • yaml文件格式:YAML(YAML Ain’t Markup Language)是一种人类可读的数据序列化格式,广泛用于配置文件和数据交换。它的设计目标是易于阅读和编写,同时支持复杂的数据结构。
  • pyproject.tomlsetup.py
    • setup.py文件的作用包括:
      1. 项目元数据
        • 定义项目的名称、版本、描述、URL、作者信息、许可证等。
      2. 读取长描述
        • README.md文件中读取项目的长描述,用于在 PyPI 上显示项目的详细信息。
      3. 依赖项管理
        • 定义项目的必需依赖项和可选依赖项。
      4. 扩展模块
        • 定义和构建 C/C++ 扩展模块,特别是 CUDA 扩展。
      5. 环境变量配置
        • 通过环境变量控制 CUDA 扩展的构建和错误处理行为。
      6. 安装配置
        • 使用 setup 函数来配置项目的安装选项,包括包的发现、依赖项的安装、扩展模块的构建等。

2. assets(资源)

  • 存放项目相关的图片和其他静态资源。

3. checkpoints

  • 包含下载模型检查点的脚本。

4. demo文件夹

  • 包含演示项目的代码和资源。
  • backend:后端代码。
  • frontend:前端代码。
  • data:演示用的数据文件。

5. notebooks文件夹

  • 包含 Jupyter Notebook 示例,用于展示项目的使用方法。
  • imagesvideos:示例用的图片和视频文件。

6. sam2文件夹

  • 项目的核心代码库。

  • configs:配置文件。

    • 通常用于定义模型训练和推理的各种参数。以下是这些配置文件中可能包含的内容:
      1. 模型架构参数
        • 模型的层数、每层的神经元数量、激活函数等。
        • 例如:sam2_hiera_b+.yamlsam2_hiera_l.yaml 等。
      2. 训练参数
        • 学习率、优化器类型、批量大小、训练轮数等。
        • 例如:sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
      3. 数据集参数
        • 数据集路径、数据预处理方法、数据增强策略等。
      4. 损失函数参数
        • 损失函数的类型及其相关参数。
      5. 评估参数
        • 评估指标、评估频率等。
      6. 其他超参数
        • 例如 dropout 率、正则化参数等。
  • csrc:C++ 和CUDA源代码,用于实现高性能计算

  • modeling

      1. __init__.py
      • 作用:初始化模块,使得该文件夹可以被视为一个 Python 包。通常用于导入包中的子模块或子包。
      1. sam2_base.py
      • 作用:定义基础模型类 SAM2Base,这是模型的核心类,包含了模型的基本架构和操作。其他模型类可以继承这个基础类并扩展其功能。
      1. sam2_utils.py
      • 作用:包含一些实用工具函数,这些函数在模型的实现和训练过程中被频繁使用。例如,位置编码、点采样、帧选择等功能。
      1. backbones 文件夹
      • 作用 :包含模型的骨干网络(backbone)实现。骨干网络是模型的基础部分,用于提取输入数据的特征。

        • __init__.py:初始化 backbones 模块。
        • image_encoder.py:定义图像编码器,用于从输入图像中提取特征。
        • hieradet.py:定义 Hiera 模型,这是一个特定的骨干网络架构,用于特征提取。
        • fpn_neck.py:定义特征金字塔网络(FPN)颈部,用于多尺度特征融合。
      1. memory_attention 文件夹
      • 作用:包含与记忆注意力机制相关的实现。记忆注意力机制用于在视频处理过程中记住和利用之前帧的信息。

        • __init__.py:初始化 memory_attention 模块。
        • memory_attention.py:定义记忆注意力机制的核心类和方法。
      1. position_encoding 文件夹
      • 作用 :包含位置编码的实现。位置编码用于在模型中引入位置信息,帮助模型理解输入数据的空间结构。

        • __init__.py:初始化 position_encoding 模块。
        • position_encoding.py:定义位置编码的具体实现。
      1. sam 文件夹
      • 作用:包含与SAM模型相关的实现。
        • __init__.py:初始化 sam 模块。
        • transformer.py:定义 SAM 模型中的变换器(Transformer)模块,用于处理序列数据。
      1. utils 文件夹
      • 作用:包含各种实用工具函数和辅助代码,这些代码在项目的不同部分中被频繁使用。

        • __init__.py:初始化 utils 模块。
        • misc.py:包含一些通用的实用函数。
        • data_utils.py:包含用于数据处理和转换的工具函数。
        • distributed.py:包含用于分布式训练的工具函数。
        • logger.py:包含用于日志记录的工具函数。
        • train_utils.py:包含用于训练过程中的辅助工具函数。
  • utils:实用工具函数。

7. sav_dataset文件夹

  • 与数据集相关的代码和资源。
  • example:数据集示例文件。
  • utils:数据集处理工具。

8. tools文件夹

  • 文件夹提供了一些工具脚本,主要用于在不同的数据集上进行半监督视频对象分割(VOS)推理和评估。通过这些脚本,用户可以方便地生成预测结果,并使用相应的数据集评估工具或服务器来评估模型的性能。

9. training文件夹

  • 包含训练相关的代码和资源。

  • assets:用于存放与训练过程相关的各种资源文件,以支持和配置训练过程

  • dataset:数据集处理代码。

  • model:对modeling的模型主体的扩展

    • 初始化参数

      • 扩展了模型的初始化参数,以支持训练过程中所需的各种配置和超参数。这些参数包括训练和评估过程中使用的概率、帧数、点采样方法等。
    • 前向传播方法

      • 定义了 forward 方法,用于处理输入数据并执行前向传播。这个方法会根据训练和评估的不同需求,预先计算图像特征或在需要时计算图像特征。
    • 特征准备

      • 定义了 _prepare_backbone_features_per_frame 方法,用于在需要时计算图像特征,并避免重复计算。
    • 输入准备

      • 定义了 prepare_prompt_inputs 方法,用于准备输入的掩码、点或框提示,并根据配置随机选择初始条件帧和需要添加校正点的帧。
    • 视频跟踪

      • 定义了 forward_tracking 方法,用于在每一帧上执行视频跟踪,并在需要时采样校正点。
    • 跟踪步骤

      • 定义了 track_step 方法,用于在每一帧上执行跟踪步骤,并根据输入提示和前一帧的输出生成当前帧的输出。
    • 迭代校正点采样

      • 定义了 _iter_correct_pt_sampling 方法,用于在训练过程中迭代采样校正点,并在每次采样后更新预测掩码。
    • 内存编码

      • track_step 方法中,调用 _encode_memory_in_output 方法,将预测掩码编码为新的内存特征,以便在未来的帧中使用。
  • scripts:从视频文件中提取帧,并将这些帧保存为 JPEG 图像

  • utils

    • 数据处理工具
      • 包含用于数据处理和转换的工具函数。例如,data_utils.py 文件可能包含用于加载和预处理数据的函数。
    • 训练辅助工具
      • 包含用于训练过程中的辅助工具函数。例如,train_utils.py 文件可能包含用于训练过程中的日志记录、检查点保存和加载等功能的函数。
    • 分布式训练工具
      • 包含用于分布式训练的工具函数。例如,distributed.py 文件可能包含用于设置和管理分布式训练环境的函数。
    • 日志记录工具
      • 包含用于日志记录的工具函数。例如,logger.py 文件可能包含用于设置和管理日志记录的函数。
    • 其他实用工具
      • 包含其他各种实用工具函数,这些函数在项目的不同部分中被频繁使用。例如,misc.py 文件可能包含一些通用的实用函数
  • train.pytrainer.py

    • 初始化和配置

      • train.py文件解析命令行参数和配置文件,初始化训练配置,并将这些配置传递给 Trainer类。
      • Trainer 类根据传递的配置进行初始化,包括设置环境变量、设备、分布式训练、日志记录、随机种子等。
    • 训练和验证

      • train.py 文件调用 Trainer类的 run方法来执行训练和验证过程。
      • Trainer类的run方法根据模式(训练、验证或仅训练)执行相应的训练和验证过程,包括数据加载、前向传播、损失计算、梯度更新和日志记录。
    • 检查点管理

      • Trainer类负责保存和加载训练检查点,以便在训练中断后恢复训练。
      • train.py文件可以通过命令行参数或配置文件指定检查点路径,以便在启动时加载检查点。

二、代码详解(部分)

代码细节以注释为主,补充内容为辅

1.train.py

(1)库的导入
import logging #python的一个标准库,用于记录日志信息。
import os
import random
import sys #标准库,用于提供对Python解释器的访问和控制。
import traceback #标准库,用于提供异常处理功能。
from argparse import ArgumentParser #用于解析命令行参数和选项的标准模块。

import submitit #submitit是一个用于在集群上运行Python代码的库。--便于分布式训练
import torch 

from hydra import compose, initialize_config_module #Hydra是一个用于构建复杂应用程序的框架,它允许将配置与代码分离。
from hydra.utils import instantiate 

from iopath.common.file_io import g_pathmgr #iopath是一个用于处理文件路径的库,提供了一种统一的接口,可以处理本地文件系统、HDFS、S3等不同的文件系统。
from omegaconf import OmegaConf #OmegaConf是一个用于处理配置文件的库,它提供了一种简单的方式来处理配置文件。

from training.utils.train_utils import makedir, register_omegaconf_resolvers #自定义的函数,用于创建文件夹和注册OmegaConf解析器。

#设置环境变量
os.environ["HYDRA_FULL_ERROR"] = "1" #设置环境变量,用于控制Hydra的错误输出。
<1> 标准库:
  • logging:用于记录日志信息。

  • os:提供与操作系统交互的功能,如环境变量、文件路径等。

    • os.environ是os 模块中的一个字典对象,它包含了当前操作系统环境中的所有环境变量。

      • 环境变量:

        1. PATH:指定可执行文件的搜索路径。当你在命令行中输入一个命令时,操作系统会在 PATH 环境变量指定的目录中搜索该命令的可执行文件。

        2. HOME:当前用户的主目录路径。

        3. USER:当前登录的用户名。

        4. SHELL:当前用户的默认 shell 程序。

        5. LANG:系统的语言和区域设置。

        6. PWD:当前工作目录。

        7. TEMPTMP:临时文件目录。

        8. EDITOR:默认文本编辑器。

        9. PYTHONPATH:Python 解释器搜索模块的路径。用于指定额外的 Python 模块搜索路径。

        10. LD_LIBRARY_PATH:指定动态链接库的搜索路径(在 Linux 系统中)。

        11. LOGNAME:当前登录的用户名,通常与 USER 相同。

        12. HOSTNAME:当前主机的名称。

        13. MAIL:当前用户的邮件存放目录。

        14. TERM:终端类型。

        15. DISPLAY:显示服务器的地址,用于图形界面程序。

  • random:用于生成随机数。

  • sys:提供对 Python 解释器的访问和控制。

  • traceback:用于处理和打印异常的堆栈跟踪。

  • argparse:用于解析命令行参数和选项的标准模块。

<2>第三方库
  • submitit:用于在集群上运行 Python 代码,便于分布式训练
  • torch:PyTorch 库
  • hydra
    • compose:用于组合和加载配置文件。
    • initialize_config_module:此 函数的主要作用是告诉 Hydra 从哪个模块或目录中查找配置文件。通过调用这个函数,可以指定配置文件所在的模块或目录,使得 Hydra 可以正确地找到并加载这些配置文件
    • instantiate:用于实例化配置对象。 --在这个文件中,可用于实例化模型对应的类
  • iopath
    • g_pathmgr:用于处理文件路径,提供统一的接口,可以处理本地文件系统、HDFS、S3 等不同的文件系统
  • omegaconf
    • OmegaConf:用于处理配置文件,提供简单的方式来处理配置文件。 (如:配置对象的类型检查,类型转换)
<3>自定义模块
  • training.utils.train_utils
    • makedir:用于创建文件夹。
    • register_omegaconf_resolvers:用于注册 OmegaConf 解析器。
      • 注册解析器的过程就是将这些自定义的函数或表达式注册到 OmegaConf 中,使得它们可以在配置文件中被识别和使用。
      • 通过注册 OmegaConf 解析器,你可以在配置文件中使用自定义的表达式和函数来动态解析和计算配置项的值。然后,Hydra 会使用这些解析器来解析配置文件中的相应表达式,从而使配置文件更加灵活和动态。
(2)训练函数
def single_proc_run(local_rank, main_port, cfg, world_size): #单进程运行
    """Single GPU process"""
    os.environ["MASTER_ADDR"] = "localhost" #设定主节点地址为本台主机
    os.environ["MASTER_PORT"] = str(main_port) #设定主节点用于进程通信的端口号
    os.environ["RANK"] = str(local_rank) #设定当前进程的全局rank
    os.environ["LOCAL_RANK"] = str(local_rank) #设定当前进程的本地rank
    os.environ["WORLD_SIZE"] = str(world_size) #设定当前进程的全局大小
    try:
        register_omegaconf_resolvers()#注册OmegaConf解析器
    except Exception as e: #捕获异常
        logging.info(e)

    trainer = instantiate(cfg.trainer, _recursive_=False) #实例化训练器
    trainer.run()

def single_node_runner(cfg, main_port: int):#单节点运行
    assert cfg.launcher.num_nodes == 1 #确保节点数为1--assert是Python中的断言语句,用于判断一个表达式,在表达式条件为False的时候触发异常。
    num_proc = cfg.launcher.gpus_per_node #获取每个节点的GPU数量
    torch.multiprocessing.set_start_method(#设置多进程启动方法
        "spawn"
    )  # CUDA runtime does not support `fork`
    if num_proc == 1:
        # directly call single_proc so we can easily set breakpoints
        #直接调用single_proc,以便轻松设置断点,从而便于调试
        # mp.spawn does not let us set breakpoints
        # mp.spawn不允许我们设置断点(就是multiprocessing.spawn)
        single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
    else:
        mp_runner = torch.multiprocessing.start_processes
        args = (main_port, cfg, num_proc)
        # Note: using "fork" below, "spawn" causes time and error regressions. Using
        # spawn changes the default multiprocessing context to spawn, which doesn't
        # interact well with the dataloaders (likely due to the use of OpenCV).
        # 笔记:在下面使用“fork”,“spawn”会导致时间和错误回归。使用spawn会将默认的多进程上下文更改为spawn,这与数据加载器不兼容(可能是由于使用OpenCV)。--所以要和CUDA不兼容做权衡取舍
        mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="spawn")
<1>single_proc_run中,不同环境变量的作用
  1. MASTER_ADDR:主节点的地址
    • 在分布式训练中,主节点负责协调和管理所有工作节点(Worker Nodes)。这里将主节点地址设置为 localhost,表示当前机器是主节点。
  2. MASTER_PORT:主节点用于进程通信的端口号
    • 主节点和工作节点之间需要通过网络进行通信,这里设置了主节点的通信端口号。
  3. RANK:就是进程的全局排名,但不是优先级
    • 在分布式训练中,每个进程都有一个唯一的全局排名(Rank),用于标识进程的身份。这里将 local_rank赋值给 RANK 环境变量。
  4. LOCAL_RANK:指定当前进程在本地机器上的排名
    • 在多 GPU 训练中,每个 GPU 都有一个本地排名(Local Rank),用于标识进程在本地机器上的身份。这里将 local_rank赋值给 LOCAL_RANK 环境变量。
  5. WORLD_SIZE:指定总的进程数
  • local rank和rank的区别:

    • 假设你有一个分布式训练系统,有 2 台机器,每台机器上有 4 个 GPU,总共有 8 个进程:

      • 机器 1:

        • 进程 0:RANK=0, LOCAL_RANK=0
        • 进程 1:RANK=1, LOCAL_RANK=1
        • 进程 2:RANK=2, LOCAL_RANK=2
        • 进程 3:RANK=3, LOCAL_RANK=3
      • 机器 2:

        • 进程 4:RANK=4, LOCAL_RANK=0
        • 进程 5:RANK=5, LOCAL_RANK=1
        • 进程 6:RANK=6, LOCAL_RANK=2
        • 进程 7:RANK=7, LOCAL_RANK=3
      • 在这个示例中,RANK 用于标识每个进程在整个分布式系统中的位置,而 LOCAL_RANK 用于标识每个进程在本地机器上的位置。

      • 在项目代码中,因为就一个GPU所以local_rank的和rank值相同

<2>训练器
  • 训练器(trainer)指的是一个用于管理和执行训练过程的对象。训练器通常包含了训练模型所需的所有逻辑,包括数据加载、模型前向传播、损失计算、反向传播和优化等步骤。
  • recursive参数
    • 这是一个可选参数,表示是否递归地实例化配置项。默认值是 True,表示递归地实例化配置项。如果设置为 False,则表示不递归地实例化配置项,这样内层的配置项不会被实例化,而是以字典形式传递
    • 递归实例化的意思是,如果配置项中包含其他配置项,Hydra 会递归地实例化这些嵌套的配置项。例如,如果一个配置项包含另一个配置项,Hydra 会首先实例化嵌套的配置项,然后再实例化外层的配置项。
    • 这里设置为false的主要原因应该是为了减少计算开销
<3>torch的多进程启动
torch.multiprocessing.set_start_method(#设置多进程启动方法
        "spawn"
    )  # CUDA runtime does not support `fork`
  • Python 的multiprocessing模块支持多种启动新进程的方法,主要包括:

    1. fork:父进程被复制(包括内存中的所有内容),子进程几乎是父进程的一个克隆。fork 是 Unix 系统的默认启动方法。
    2. spawn:父进程启动一个全新的 Python 解释器进程。子进程只继承运行时的必要资源,不会继承父进程的内存内容。spawn 是 Windows 系统的默认启动方法。
    3. forkserver:启动一个服务器进程,父进程和子进程通过服务器进程进行通信。
  • 为什么用spawn

    • CUDA 运行时不支持 fork 方法,因为 fork 会复制父进程的内存内容,包括 CUDA 上下文,这可能导致 CUDA 资源的竞争和不一致。因此,在使用 CUDA 时,推荐使用 spawn 方法来启动新进程。
  • torch.multiprocessing模块作用

    • torch.multiprocessing提供了与 Python 标准库 multiprocessing类似的接口和功能,包括:

      1. 进程管理:创建和管理多个进程。
      2. 进程间通信:通过管道(Pipe)和队列(Queue)在进程之间传递数据。
      3. 共享内存:在多个进程之间共享张量(Tensor)和其他数据。
      4. 同步机制:提供锁(Lock)、事件(Event)、信号量(Semaphore)等同步原语。
  • torch.multiprocessing.start_processes的参数

    • fn:要在每个进程中运行的目标函数。
    • args:传递给目标函数的参数元组。
    • nprocs:要启动的进程数。
    • join:是否等待所有进程完成。默认为 True
    • daemon:是否将进程设置为守护进程。默认为 False
    • start_method:启动新进程的方法。可以是 'fork''spawn''forkserver'。默认为 'spawn'
(3)格式化异常
def format_exception(e: Exception, limit=20): #格式化异常,限制堆栈深度为20
    traceback_str = "".join(traceback.format_tb(e.__traceback__, limit=limit)) #获取异常的堆栈信息,并将其格式化为字符串
    return f"{type(e).__name__}: {e}\nTraceback:\n{traceback_str}" #返回异常的类型和信息,以及堆栈信息
  • join是 Python 字符串的一个方法,用于将一个可迭代对象(如列表、元组等)中的元素连接成一个字符串。

    separator.join(iterable)
    
    • separator:用于分隔每个元素的字符串。
    • iterable:一个可迭代对象,其元素将被连接成一个字符串。
  • f-string方法

    • f"{…}":这是 f-string 的语法,允许在字符串中嵌入表达式
(4)Submitit运行器类
class SubmititRunner(submitit.helpers.Checkpointable): #SubmititRunner类继承自submitit.helpers.Checkpointable类
    #Checkpointable是submitit的一个类,用于支持checkpointing功能,即在任务执行过程中定期保存任务状态,以便在任务失败时恢复任务状态。
    """A callable which is passed to submitit to launch the jobs."""
    #一个可调用对象,用于传递给submitit以启动作业。
    def __init__(self, port, cfg):# 构造器
        self.cfg = cfg
        self.port = port
        self.has_setup = False

    def run_trainer(self): #运行训练器,用于设置环境变量,并启动训练过程--用于在分布式环境中启动训练过程
        job_env = submitit.JobEnvironment()
        #获取当前作业的环境信息
        # Need to add this again so the hydra.job.set_env PYTHONPATH
        # is also set when launching jobs.
        #需要再次添加这个,以便在启动作业时也设置hydra.job.set_env PYTHONPATH。
        add_pythonpath_to_sys_path()#将PYTHONPATH添加到sys.path中
        os.environ["MASTER_ADDR"] = job_env.hostnames[0] #设定主节点地址
        os.environ["MASTER_PORT"] = str(self.port) #设定主节点端口
        os.environ["RANK"] = str(job_env.global_rank)#设定全局rank
        os.environ["LOCAL_RANK"] = str(job_env.local_rank)#设定本地rank
        os.environ["WORLD_SIZE"] = str(job_env.num_tasks)#设定全局大小

        register_omegaconf_resolvers()#注册OmegaConf解析器
        cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)#将配置转换为容器
        #表示将OmegaConf对象转换为Python对象,resolve=False表示不解析插值和引用
        cfg_resolved = OmegaConf.create(cfg_resolved)#创建配置文件

        trainer = instantiate(cfg_resolved.trainer, _recursive_=False)
        trainer.run()

    def __call__(self): #__call__ 方法允许类的实例具有“可调用”的行为,即将对象当作函数来使用。
        #当实例被当作函数调用时,会执行__call__中的所有代码
        job_env = submitit.JobEnvironment()
        self.setup_job_info(job_env.job_id, job_env.global_rank)
        try:
            self.run_trainer()#运行训练器
        except Exception as e:
            # Log the exception. Then raise it again (as what SubmititRunner currently does).
            message = format_exception(e)
            logging.error(message)
            raise e

    def setup_job_info(self, job_id, rank):#设置作业信息
        """Set up slurm job info"""
        self.job_info = {
            "job_id": job_id,
            "rank": rank,#全局rank
            "cluster": self.cfg.get("cluster", None),#获取集群信息
            "experiment_log_dir": self.cfg.launcher.experiment_log_dir,#获取实验日志目录
        }

        self.has_setup = True #作业信息已设置
<1>pythonpath和sys.path
  • pythonpath是一个环境变量,它告诉Python解释器在哪里查找模块。
  • sys.path是一个包含字符串的列表,它包含了一个Python解释器在运行时查找模块的目录的列表,包括:
    • 当前脚本所在的目录:这是 Python 解释器启动时所在的目录。
    • 标准库目录:包含 Python 标准库的目录。
    • 安装的第三方包目录:包含通过包管理工具(如 pip)安装的第三方包的目录。
    • PYTHONPATH 环境变量指定的目录:这些目录是通过设置 PYTHONPATH 环境变量添加的。
<2>submitit.JobEnvironment():
  • 一些常见属性:
    1. hostnames:一个列表,包含当前作业运行的所有主机名。
    2. global_rank:当前进程在所有进程中的全局排名。
    3. local_rank:当前进程在本地机器上的排名。
    4. num_tasks:参与作业的总任务数。
    5. job_id:当前作业的 ID。
<3>插值和引用

在run_trainer()中,对Omegaconf解析器进行了新的参数配置

register_omegaconf_resolvers()#注册OmegaConf解析器
cfg_resolved = OmegaConf.to_container(self.cfg, resolve=False)#将配置转换为容器
#表示将OmegaConf对象转换为Python对象,resolve=False表示不解析插值和引用
cfg_resolved = OmegaConf.create(cfg_resolved)#创建配置文件
  • 代码流程

    • 将OmegaConf对象转换成python容器
    • 再将python容器转换回配置对象
  • 为什么这么做?

    • 深拷贝:通过这种方式,可以创建配置对象的一个深拷贝,确保原始配置对象不被修改。
    • 清理状态:有时,配置对象可能包含一些临时状态或未解析的引用,通过这种转换可以清理这些状态。
    • 兼容性:在某些情况下,可能需要将配置对象转换为普通容器进行某些操作,然后再转换回 OmegaConf 配置对象以继续使用 OmegaConf 的特性。
  • resolve=False:不解析插值和引用

    • 插值是指在配置文件中使用占位符来引用其他配置项的值。占位符通常以特定的语法表示,例如 ${}。当配置文件被解析时,占位符会被替换为实际的值。

      • 示例:
      database:
       host: localhost
       port: 5432
       url: postgresql://${database.host}:${database.port}/mydb
      
      • 在这个示例中,url 配置项使用了插值 ${database.host}${database.port} 来引用 hostport的值。当配置文件被解析时,url 的值会被替换为 postgresql://localhost:5432/mydb
    • 引用是指在配置文件中使用变量来引用其他配置项的值。引用通常用于在多个地方使用相同的配置值,从而避免重复定义。

      • 示例
      defaults:
        - &default_host localhost
        - &default_port 5432
      
      database:
        host: *default_host
        port: *default_port
        url: postgresql://*default_host:*default_port/mydb
      
      • 在这个示例中,hostport配置项使用了引用 *default_host*default_port 来引用默认的主机和端口值。当配置文件被解析时,hostport 的值会被替换为 localhost5432
(5)main主函数
def main(args) -> None:
    cfg = compose(config_name=args.config) #使用Hydra的compose函数加载配置文件
    if cfg.launcher.experiment_log_dir is None: #如果实验日志目录为空
        cfg.launcher.experiment_log_dir = os.path.join(
            os.getcwd() , "sam2_logs", args.config 
            #os.getcwd()获取当前工作目录
        )
    print("###################### Train App Config ####################")
    print(OmegaConf.to_yaml(cfg)) #将配置文件转换为yaml格式
    print("############################################################")

    add_pythonpath_to_sys_path()
    makedir(cfg.launcher.experiment_log_dir) #创建实验日志目录
    with g_pathmgr.open(#使用iopath打开文件---保存原始配置
        os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg))#将配置文件写入文件

    cfg_resolved = OmegaConf.to_container(cfg, resolve=False)
    cfg_resolved = OmegaConf.create(cfg_resolved)

    with g_pathmgr.open(#保存解析后的配置文件
        os.path.join(cfg.launcher.experiment_log_dir, "config_resolved.yaml"), "w"
    ) as f:
        f.write(OmegaConf.to_yaml(cfg_resolved, resolve=True))#将解析后的配置文件写入文件

    submitit_conf = cfg.get("submitit", None)#获取submitit配置--即集群的配置,
    assert submitit_conf is not None, "Missing submitit config"#确保submitit配置不为空

    submitit_dir = cfg.launcher.experiment_log_dir#获取实验日志目录
    submitit_dir = os.path.join(submitit_dir, "submitit_logs")#将submitit日志目录添加到实验日志目录中
    # Priotrize cmd line args
    cfg.launcher.gpus_per_node = (#获取GPU数量
        args.num_gpus if args.num_gpus is not None else cfg.launcher.gpus_per_node
        #三元表达式,如果args.num_gpus(命令行中参数)不为空,则使用args.num_gpus,否则使用cfg.launcher.gpus_per_node
    )
    cfg.launcher.num_nodes = (#获取节点数量
        args.num_nodes if args.num_nodes is not None else cfg.launcher.num_nodes
    )
    submitit_conf.use_cluster = (#是否使用集群
        args.use_cluster if args.use_cluster is not None else submitit_conf.use_cluster
    )
    if submitit_conf.use_cluster:
        executor = submitit.AutoExecutor(folder=submitit_dir)#创建submitit执行器
        #folder是一个字符串,指定了存储作业输出的目录。如果未指定,则使用当前目录。
        submitit_conf.partition = (
            args.partition
            if args.partition is not None
            else submitit_conf.get("partition", None)
        )
        submitit_conf.account = (#获取账户
            args.account
            if args.account is not None
            else submitit_conf.get("account", None)
        )
        submitit_conf.qos = (#获取QOS
            args.qos if args.qos is not None else submitit_conf.get("qos", None)
        )
        #QOS是Quality of Service的缩写,意为服务质量。在SLURM中,QOS是用来控制作业的优先级的。
        job_kwargs = {#作业参数
            "timeout_min": 60 * submitit_conf.timeout_hour,#超时时间
            "name": (
                submitit_conf.name if hasattr(submitit_conf, "name") else args.config#作业名称
            ),
            "slurm_partition": submitit_conf.partition,#SLURM分区
            "gpus_per_node": cfg.launcher.gpus_per_node,#每个节点的GPU数量
            "tasks_per_node": cfg.launcher.gpus_per_node,  # one task per GPU
            "cpus_per_task": submitit_conf.cpus_per_task,#每个任务的CPU数量
            "nodes": cfg.launcher.num_nodes,#节点数量
            "slurm_additional_parameters": {
                "exclude": " ".join(submitit_conf.get("exclude_nodes", [])),#排除节点
            },
        }
        if "include_nodes" in submitit_conf:
            assert (
                len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes#确保节点数量足够
            ), "Not enough nodes"
            job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(#要包含的节点列表
                submitit_conf["include_nodes"]
            )
        if submitit_conf.account is not None:
            job_kwargs["slurm_additional_parameters"]["account"] = submitit_conf.account
        if submitit_conf.qos is not None:
            job_kwargs["slurm_additional_parameters"]["qos"] = submitit_conf.qos

        if submitit_conf.get("mem_gb", None) is not None:
            job_kwargs["mem_gb"] = submitit_conf.mem_gb#内存大小--单位GB
        elif submitit_conf.get("mem", None) is not None:
            job_kwargs["slurm_mem"] = submitit_conf.mem#内存大小--单位MB

        if submitit_conf.get("constraints", None) is not None:#一些约束--如指定GPU型号
            job_kwargs["slurm_constraint"] = submitit_conf.constraints

        if submitit_conf.get("comment", None) is not None:
            job_kwargs["slurm_comment"] = submitit_conf.comment#slurm的作业注释,帮助管理员理解作业的注释

        # Supports only cpu-bind option within srun_args. New options can be added here
        if submitit_conf.get("srun_args", None) is not None:#参数用于控制作业的运行行为
            job_kwargs["slurm_srun_args"] = []
            if submitit_conf.srun_args.get("cpu_bind", None) is not None:
                job_kwargs["slurm_srun_args"].extend(
                    ["--cpu-bind", submitit_conf.srun_args.cpu_bind]
                )

        print("###################### SLURM Config ####################")
        print(job_kwargs)#打印作业参数
        print("##########################################")
        executor.update_parameters(**job_kwargs)#更新作业参数

        main_port = random.randint(#随机生成一个端口号,范围在submitit_conf.port_range[0]和submitit_conf.port_range[1]之间
            submitit_conf.port_range[0], submitit_conf.port_range[1]#用于进程通信
        )
        runner = SubmititRunner(main_port, cfg)
        job = executor.submit(runner)
        print(f"Submitit Job ID: {job.job_id}")
        runner.setup_job_info(job.job_id, rank=0)
    else:#单节点运行
        cfg.launcher.num_nodes = 1
        main_port = random.randint(
            submitit_conf.port_range[0], submitit_conf.port_range[1]
        )
        single_node_runner(cfg, main_port)


if __name__ == "__main__":#如果模块是被直接运行的,则代码块将被运行;如果模块是被导入的,则代码块不被运行。

    initialize_config_module("sam2", version_base="1.2")
    parser = ArgumentParser()#创建ArgumentParser对象
    parser.add_argument(#添加参数
        "-c",
        "--config",
        required=True,#将此参数设置为必须参数
        type=str,
        help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)",
    )
    parser.add_argument(
        "--use-cluster",#是否使用集群
        type=int,
        default=None,
        help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
    )
    parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
    parser.add_argument("--account", type=str, default=None, help="SLURM account")
    parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
    parser.add_argument(
        "--num-gpus", type=int, default=None, help="number of GPUS per node"
    )
    parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")
    args = parser.parse_args()
    args.use_cluster = bool(args.use_cluster) if args.use_cluster is not None else None
    register_omegaconf_resolvers()
    main(args)

<1>with关键字&g_pathmgr.open方法
  • g_pathmgr.open

    • iopath库中的一个方法,用于打开文件
    • 类似于python中的open,但是它可以打开更多的文件
  • with和as

    • 在 Python 中,with 语句和 as 关键字用于上下文管理器(context manager),它们提供了一种简洁的方式来管理资源,如文件、网络连接、锁等。上下文管理器确保资源在使用完毕后被正确地释放或关闭,即使在发生异常的情况下也是如此。

    • 执行流程

      • 打开文件

        with g_pathmgr.open(
        
          os.path.join(cfg.launcher.experiment_log_dir, "config.yaml"), "w"
        
        ) as f:
        

        这行代码使用 g_pathmgr.open方法打开一个文件进行写操作。文件路径是通过 os.path.join方法生成的,文件名为 config.yaml,位于 cfg.launcher.experiment_log_dir目录下。g_pathmgr.open返回一个文件对象,该对象实现了上下文管理器协议。

      • 写入数据

        f.write(OmegaConf.to_yaml(cfg))
        

        这行代码将 cfg配置对象转换为 YAML 格式,并写入到打开的文件(config.yaml)中。

      • 自动关闭文件: 当 with 语句块结束时,无论是否发生异常,文件都会被自动关闭。这是通过文件对象的 __exit__ 方法实现的。

    <2>保存原始配置和解析后的配置
  • 为什么要保存原始配置?

    • 保存原始配置的原因

      1. 保留用户输入:原始配置文件保留了用户最初提供的配置内容,包括所有的注释和格式。这对于审计和调试非常有用。
      2. 可读性:原始配置文件通常更具可读性,因为它保留了用户定义的结构和注释。
      3. 复现性:保留原始配置文件可以确保你能够完全复现最初的配置环境。
    • 保存解析后配置的原因

      1. 解析插值和引用:解析后的配置文件将所有的插值和引用都解析为具体的值,确保配置的完整性和正确性。
      2. 实际使用的配置:解析后的配置文件展示了最终在程序中实际使用的配置内容,这对于调试和验证非常重要。
<2>dict.get()
  • Nonedict.get 方法的默认值参数。dict.get 方法用于从字典中获取指定键的值,如果键不存在,则返回默认值。

  • dict.get 方法的语法如下:

    dict.get(key, default=None)
    
    • key:要获取的键。

    • default:如果键不存在时返回的默认值。默认值是 None

2.sam2_base.py

(1)库的导入
import torch
import torch.distributed # 导入torch.distributed,用于分布式训练
import torch.nn.functional as F # 导入torch.nn.functional,用于激活函数

from torch.nn.init import trunc_normal_ # 导入trunc_normal_,用于初始化权重

from sam2.modeling.sam.mask_decoder import MaskDecoder 
from sam2.modeling.sam.prompt_encoder import PromptEncoder
from sam2.modeling.sam.transformer import TwoWayTransformer # 导入TwoWayTransformer,用于SAM的transformer
from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames # 用于SAM2的一些工具函数

# a large negative value as a placeholder score for missing objects
NO_OBJ_SCORE = -1024.0 # 用于表示没有对象的分数
#在生成mask时,如果某个位置没有检测到对象,可以使用一个很大的负数来表示这个位置的分数
<1>torch.nn.functional
  • 激活函数:

    • F.relu(input): 计算 ReLU 激活函数。
    • F.sigmoid(input): 计算 Sigmoid 激活函数。
    • F.tanh(input): 计算 Tanh 激活函数。
  • 卷积操作:

    • F.conv2d(input, weight): 进行 2D 卷积操作。
    • F.conv3d(input, weight): 进行 3D 卷积操作。
  • 池化操作:

    • F.max_pool2d(input, kernel_size): 进行 2D 最大池化操作。
    • F.avg_pool2d(input, kernel_size): 进行 2D 平均池化操作。
  • 损失函数:

    • F.cross_entropy(input, target): 计算交叉熵损失。
    • F.mse_loss(input, target): 计算均方误差损失。
  • 归一化操作:

    • F.batch_norm(input, running_mean, running_var): 进行批归一化操作
    • F.layer_norm(input, normalized_shape): 进行层归一化操作
<2>2D卷积&3D卷积

2D 卷积和 3D 卷积是卷积神经网络(CNN)中用于提取特征的基本操作,它们的主要区别在于输入数据的维度和应用场景。

  • 2D 卷积

    • 定义:2D 卷积是对二维数据(如图像)进行卷积操作,用于提取空间特征。

      • 输入数据
        • 输入张量的形状为 (batch_size, in_channels, height, width),其中:
          • batch_size:批大小。
          • in_channels:输入通道数(如 RGB 图像的通道数为 3)。
          • heightwidth:图像的高度和宽度。
      • 卷积核
        • 卷积核的形状为 (out_channels, in_channels, kernel_height, kernel_width),其中:
          • out_channels:输出通道数(卷积核的数量)。
          • in_channels:输入通道数。
          • kernel_heightkernel_width:卷积核的高度和宽度。
      • 操作过程
        • 卷积核在输入图像上滑动,计算每个位置的加权和,生成一个特征图。
        • 输出张量的形状为 (batch_size, out_channels, output_height, output_width)
    • 示例代码

import torch
import torch.nn.functional as F

# 输入数据 (batch_size=1, in_channels=3, height=64, width=64)
input = torch.randn(1, 3, 64, 64)

# 卷积核 (out_channels=16, in_channels=3, kernel_height=3, kernel_width=3)
weight = torch.randn(16, 3, 3, 3)

# 2D 卷积
output = F.conv2d(input, weight)
print(output.shape)  # 输出形状: (1, 16, 62, 62)
  • 3D 卷积
    • 定义:3D 卷积是对三维数据(如视频或 3D 体数据)进行卷积操作,用于提取空间和时间特征。

      • 输入数据
        • 输入张量的形状为 (batch_size, in_channels, depth, height, width),其中:
          • batch_size:批大小。
          • in_channels:输入通道数。
          • depth:数据的深度(如视频的帧数或 3D 数据的深度)。
          • heightwidth:数据的高度和宽度。
      • 卷积核
        • 卷积核的形状为 (out_channels, in_channels, kernel_depth, kernel_height, kernel_width),其中:
          • out_channels:输出通道数。
          • in_channels:输入通道数。
          • kernel_depth:卷积核的深度,决定了卷积核在输入数据的深度方向上覆盖的范围。
          • kernel_heightkernel_width:卷积核的高度和宽度。
      • 操作过程
        • 卷积核在输入数据上滑动,计算每个位置的加权和,生成一个特征图。
        • 输出张量的形状为 (batch_size, out_channels, output_depth, output_height, output_width)
    • 示例代码

import torch
import torch.nn.functional as F

# 输入数据 (batch_size=1, in_channels=3, depth=10, height=64, width=64)
input = torch.randn(1, 3, 10, 64, 64)

# 卷积核 (out_channels=16, in_channels=3, kernel_depth=3, kernel_height=3, kernel_width=3)
weight = torch.randn(16, 3, 3, 3, 3)

# 3D 卷积
output = F.conv3d(input, weight)
print(output.shape)  # 输出形状: (1, 16, 8, 62, 62)
<3>torch.randn
  • torch.randn 是 PyTorch 中用于生成**服从标准正态分布(均值为 0,标准差为 1)**的随机张量的函数。它的用法非常灵活,可以根据需要生成不同形状的张量。以下是详细说明和示例:

  • ** 基本语法**

torch.randn(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)
  • 参数说明

    • *size:张量的形状(可以是多个整数,如 (3, 4)3, 4)。
    • out(可选):输出张量。
    • dtype(可选):张量的数据类型(如 torch.float32)。
    • layout(可选):张量的内存布局(默认是 torch.strided)。
    • device(可选):张量所在的设备(如 'cpu''cuda')。
    • requires_grad(可选):是否需要对张量求梯度(默认是 False)。
  • 示例

x = torch.randn(3, 4)  # 生成一个 3 行 4 列的矩阵
print(x)
'''
输出
tensor([[ 0.1234, -0.5678,  1.2345, -0.4321],
        [ 0.9876, -0.1234,  0.5678, -1.2345],
        [-0.4321,  0.9876, -0.5678,  1.2345]])
'''
<4>torch.nn.init
  • torch.nn.init 模块包含了许多用于初始化神经网络权重的函数。这些函数可以帮助我们以不同的方式初始化模型的参数,从而影响模型的训练效果。

  • 主要功能

    1. 常用初始化方法
      • init.xavier_uniform_(tensor): 使用 Xavier 均匀分布初始化张量。
      • init.xavier_normal_(tensor)使用 Xavier 正态分布初始化张量。
      • init.kaiming_uniform_(tensor): 使用 Kaiming 均匀分布初始化张量。
      • init.kaiming_normal_(tensor): 使用 Kaiming 正态分布初始化张量。
    2. 常数初始化
      • init.constant_(tensor, val): 将张量初始化为常数 val
      • init.zeros_(tensor): 将张量初始化为全零。
      • init.ones_(tensor): 将张量初始化为全一。
    3. 随机初始化
      • init.uniform_(tensor, a, b): 使用均匀分布 [a, b] 初始化张量。
      • init.normal_(tensor, mean, std): 使用正态分布 N(mean, std) 初始化张量。
  • 示例

import torch
import torch.nn as nn
import torch.nn.init as init

# 定义一个线性层
linear = nn.Linear(3, 3)

# 使用 Xavier 均匀分布初始化权重
init.xavier_uniform_(linear.weight)
print(linear.weight)

'''
可能输出
tensor([[ 0.1234, -0.5678,  0.4321],
        [ 0.9876, -0.1234, -0.5678],
        [-0.4321,  0.9876,  0.1234]], requires_grad=True)
'''
(2)SAM2Base类定义–没看完

实在是太长了,看不完:(

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed # 导入torch.distributed,用于分布式训练
import torch.nn.functional as F # 导入torch.nn.functional,用于激活函数

from torch.nn.init import trunc_normal_ # 导入trunc_normal_,用于初始化权重

from sam2.modeling.sam.mask_decoder import MaskDecoder 
from sam2.modeling.sam.prompt_encoder import PromptEncoder
from sam2.modeling.sam.transformer import TwoWayTransformer # 导入TwoWayTransformer,用于SAM的transformer
from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames # 用于SAM2的一些工具函数

# a large negative value as a placeholder score for missing objects
NO_OBJ_SCORE = -1024.0 # 用于表示没有对象的分数
#在生成mask时,如果某个位置没有检测到对象,可以使用一个很大的负数来表示这个位置的分数

class SAM2Base(torch.nn.Module):
    def __init__(
        self,
        image_encoder,
        memory_attention,
        memory_encoder,
        num_maskmem=7,  #默认设置:一张input,6张previous frames
        image_size=512, #输入图像的大小
        backbone_stride=16,  # 16 for ResNet50-FPN, 32 for ResNet50,图像编码器的步长
        sigmoid_scale_for_mem_enc=1.0,  # sigmoid函数的缩放因子
        sigmoid_bias_for_mem_enc=0.0,   # sigmoid函数的偏置

        
        binarize_mask_from_pts_for_mem_enc=False,# 在评估过程中,是否对与点击交互的帧上的sigmoid mask logits进行二值化
        #二值化是指将一个连续的值转换为一个二进制值,通常是0或1--大于某个阈值的值转换为1,小于某个阈值的值转换为0
        #这里是指将sigmoid mask logits转换为二进制值,即大于0.5的值转换为1,小于0.5的值转换为0

        use_mask_input_as_output_without_sam=False,  # 在带有mask输入的帧上,是否直接输出输入mask,而不使用SAM提示编码器+mask解码器

        
        max_cond_frames_in_attn=-1,# 参与记忆注意力的最大条件帧数(-1表示没有限制;如果有的话,我们只在跟踪每帧时在编码器中跨越注意到最接近时间的`max_cond_frames_in_attn`条件帧)。
        #条件帧指的是历史帧或者其他相关帧

        directly_add_no_mem_embed=False,# 在第一帧上,不直接将无记忆嵌入添加到图像特征中,而是使用transformer编码器

        # whether to use high-resolution feature maps in the SAM mask decoder
        use_high_res_features_in_sam=False,

        multimask_output_in_sam=False,# 若在初始条件帧上进行第一次点击,是否输出多个(3个)mask
        # 初始条件帧的作用是在训练过程中,用于初始化模型的帧,即模型的第一帧输入

      
        multimask_min_pt_num=1,# 在SAM中,多mask输出的最小点击数
        multimask_max_pt_num=1,# 在SAM中,多mask输出的最大点击数
        #意思是用户只能点击一次

       
        multimask_output_for_tracking=False,# 是否在跟踪中也使用多mask输出(不仅仅是在初始条件帧上的第一次点击;仅当`multimask_output_in_sam=True`时才相关)
        # 这个参数的作用是在跟踪过程中,是否也输出多个mask
        # 如果为True,则在跟踪过程中也会输出多个mask,如果为False,则只输出一个mask


        use_multimask_token_for_obj_ptr: bool = False,
        # 作用是在SAM的mask解码器中,是否使用多mask token来预测对象指针
        # mask_token是指在SAM的mask解码器中,用于预测对象指针的token,是一个嵌入式向量


        iou_prediction_use_sigmoid=False,# 是否使用sigmoid函数将iou预测限制在[0-1]之间


        memory_temporal_stride_for_eval=1,
        # 在评估过程中,记忆库的时间步长(即XMem和Cutie中的`r`参数;XMem和Cutie使用r=5)。


        non_overlap_masks_for_mem_enc=False,
        # 在评估过程中,是否在记忆编码器中对对象mask应用非重叠约束(以避免/减轻叠加mask)

        
        use_obj_ptrs_in_encoder=False,
        # 在编码器中是否跨帧交叉注意力到其他帧的对象指针(仅在`use_obj_ptrs_in_encoder=True`时相关)


        # 在编码器中最大的对象指针数
        max_obj_ptrs_in_encoder=16,


        
        add_tpos_enc_to_obj_ptrs=True,
        # 在编码器中是否为对象指针添加时间位置编码(仅在`use_obj_ptrs_in_encoder=True`时相关)


        proj_tpos_enc_in_obj_ptrs=False,
        # 在对象指针中是否为时间位置编码添加额外的线性投影层,以避免潜在的干扰空间位置编码(仅在`use_obj_ptrs_in_encoder=True`和`add_tpos_enc_to_obj_ptrs=True`时相关)
        
        use_signed_tpos_enc_to_obj_ptrs=False,
        # 在对象指针中是否使用有符号的时间位置编码(仅在`use_obj_ptrs_in_encoder=True`和`add_tpos_enc_to_obj_ptrs=True`时相关)

        only_obj_ptrs_in_the_past_for_eval=False,
        # 在评估过程中,是否只使用过去的对象指针(仅在`use_obj_ptrs_in_encoder=True`时相关)
        

        pred_obj_scores: bool = False,
        # 不使用MLP来预测对象分数,而是maskdecoder的输出来确定对象分数
        # Whether to use an MLP to predict object scores
        pred_obj_scores_mlp: bool = False,


        # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
        # Whether to have a fixed no obj pointer when there is no object present
        # or to use it as an additive embedding with obj_ptr produced by decoder
        fixed_no_obj_ptr: bool = False,# 是否在没有对象时使用固定的无对象指针
        soft_no_obj_ptr: bool = False,# 是否在没有对象时使用软无对象指针,即在没有对象时,使用无对象指针的混合,希望使恢复更容易,如果有错误的话,并减少错误的积累
   
        use_mlp_for_obj_ptr_proj: bool = False,# 是否在对象指针投影中使用MLP
        # add no obj embedding to spatial frames
        no_obj_embed_spatial: bool = False,# 是否将无对象嵌入添加到空间帧中
       
        sam_mask_decoder_extra_args=None, # 用于构建SAM mask解码器的额外参数;如果不是None,则它应该是一个传递给`MaskDecoder`类的kwargs字典。
        compile_image_encoder: bool = False, # 是否编译图像编码器的forward函数
    ):
        super().__init__()

        # 初始化
        # Part 1: the image backbone
        self.image_encoder = image_encoder
        # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
        self.use_high_res_features_in_sam = use_high_res_features_in_sam
        self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 # 使用高分辨率特征的级别
        self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
        self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder


        if use_obj_ptrs_in_encoder:
            '''
                这段代码的作用是​在编码器中使用对象指针(object pointers),
                并通过 ​卷积层 将​掩码提示(mask prompt)下采样到与​低分辨率 SAM 掩码 logits 相同的 stride(步幅),
                并将其比例从 [0, 1] 转换为​SAM logit 比例,以便将其输入到​SAM 掩码解码器 中生成指针。
            
            '''
            self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)

        self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs # 是否为对象指针添加时间位置编码


        if proj_tpos_enc_in_obj_ptrs:
            assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
        self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
        self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval

        # Part 2: memory attention to condition current frame's visual features
        # 第2部分:记忆注意力以调整当前帧的视觉特征

        self.memory_attention = memory_attention
        self.hidden_dim = image_encoder.neck.d_model# 隐藏维度,即图像编码器的维度
        '''
        隐藏维度(hidden dimension)通常指的是隐藏层的特征向量的维度。
        隐藏层是神经网络中介于输入层和输出层之间的层,它们负责对输入数据进行特征提取和变换。
        隐藏维度决定了隐藏层中每个神经元的输出特征向量的大小。
        '''
        # Part 3: memory encoder for the previous frame's outputs
        self.memory_encoder = memory_encoder
        self.mem_dim = self.hidden_dim

        '''
        这段代码的作用是 ​检查 self.memory_encoder 是否存在 out_proj 属性,并且 out_proj 是否包含 weight 属性。
        如果条件成立,则获取 out_proj.weight 的 ​第一个维度的大小,并将其赋值给 self.mem_dim。
        
        '''
        if hasattr(self.memory_encoder, "out_proj") and hasattr(# 如果记忆编码器有out_proj属性和weight属性
            #out_proj是记忆编码器的输出投影层,weight是权重
            self.memory_encoder.out_proj, "weight"
        ):
            # if there is compression of memories along channel dim
            self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]




        '''
        ​self.num_maskmem = num_maskmem:将 num_maskmem 赋值给 self.num_maskmem,表示 ​记忆库中 mask 的数量。
        num_maskmem 是一个整数,表示记忆库中存储的 mask 数量。
        ​self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)):
        初始化一个​可学习的参数,用于表示​时间位置编码。
        torch.zeros(num_maskmem, 1, 1, self.mem_dim) 创建一个形状为 [num_maskmem, 1, 1, self.mem_dim] 的全零张量。
        torch.nn.Parameter 将其转换为可学习的参数。
        '''
        self.num_maskmem = num_maskmem  # 记忆库mask的数量
        # Temporal encoding of the memories
        self.maskmem_tpos_enc = torch.nn.Parameter(
            torch.zeros(num_maskmem, 1, 1, self.mem_dim)
            #第一个参数表示有多少个mask
            #第二三个参数是占位符
            #第四个参数是特征的维度
        )


        trunc_normal_(self.maskmem_tpos_enc, std=0.02)
        # a single token to indicate no memory embedding from previous frames
        self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        trunc_normal_(self.no_mem_embed, std=0.02)
        trunc_normal_(self.no_mem_pos_enc, std=0.02)
        self.directly_add_no_mem_embed = directly_add_no_mem_embed
        # Apply sigmoid to the output raw mask logits (to turn them from
        # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
        self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
        self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
        self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
        self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
        self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
        # On frames with mask input, whether to directly output the input mask without
        # using a SAM prompt encoder + mask decoder
        self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
        self.multimask_output_in_sam = multimask_output_in_sam
        self.multimask_min_pt_num = multimask_min_pt_num
        self.multimask_max_pt_num = multimask_max_pt_num
        self.multimask_output_for_tracking = multimask_output_for_tracking
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
        self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid

        # Part 4: SAM-style prompt encoder (for both mask and point inputs)
        # and SAM-style mask decoder for the final mask output
        self.image_size = image_size
        self.backbone_stride = backbone_stride
        self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
        self.pred_obj_scores = pred_obj_scores
        self.pred_obj_scores_mlp = pred_obj_scores_mlp
        self.fixed_no_obj_ptr = fixed_no_obj_ptr
        self.soft_no_obj_ptr = soft_no_obj_ptr
        if self.fixed_no_obj_ptr:
            assert self.pred_obj_scores
            assert self.use_obj_ptrs_in_encoder
        if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
            self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
            trunc_normal_(self.no_obj_ptr, std=0.02)
        self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
        self.no_obj_embed_spatial = None
        if no_obj_embed_spatial:
            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
            trunc_normal_(self.no_obj_embed_spatial, std=0.02)

        self._build_sam_heads()
        self.max_cond_frames_in_attn = max_cond_frames_in_attn

        # Model compilation
        if compile_image_encoder:
            # Compile the forward function (not the full module) to allow loading checkpoints.
            print(
                "Image encoder compilation is enabled. First forward pass will be slow."
            )
            self.image_encoder.forward = torch.compile(
                self.image_encoder.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=False,
            )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
            "See notebooks/video_predictor_example.ipynb for an inference example."
        )

    def _build_sam_heads(self):
        """Build SAM-style prompt encoder and mask decoder."""
        self.sam_prompt_embed_dim = self.hidden_dim
        self.sam_image_embedding_size = self.image_size // self.backbone_stride

        # build PromptEncoder and MaskDecoder from SAM
        # (their hyperparameters like `mask_in_chans=16` are from SAM code)
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
                self.sam_image_embedding_size,
                self.sam_image_embedding_size,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
        self.sam_mask_decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=self.use_high_res_features_in_sam,
            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
            pred_obj_scores=self.pred_obj_scores,
            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
            **(self.sam_mask_decoder_extra_args or {}),
        )
        if self.use_obj_ptrs_in_encoder:
            # a linear projection on SAM output tokens to turn them into object pointers
            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
            if self.use_mlp_for_obj_ptr_proj:
                self.obj_ptr_proj = MLP(
                    self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
                )
        else:
            self.obj_ptr_proj = torch.nn.Identity()
        if self.proj_tpos_enc_in_obj_ptrs:
            # a linear projection on temporal positional encoding in object pointers to
            # avoid potential interference with spatial positional encoding
            self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
        else:
            self.obj_ptr_tpos_proj = torch.nn.Identity()

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Forward SAM prompt encoders and mask heads.

        Inputs:
        - backbone_features: image features of [B, C, H, W] shape
        - point_inputs: a dictionary with "point_coords" and "point_labels", where
          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
             absolute pixel-unit coordinate in (x, y) format of the P input points
          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
             positive clicks, 0 means negative clicks, and -1 means padding
        - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
          same spatial size as the image.
        - high_res_features: either 1) None or 2) or a list of length 2 containing
          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
          which will be used as high-resolution feature maps for SAM decoder.
        - multimask_output: if it's True, we output 3 candidate masks and their 3
          corresponding IoU estimates, and if it's False, we output only 1 mask and
          its corresponding IoU estimate.

        Outputs:
        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
          output mask logits (before sigmoid) for the low-resolution masks, with 4x
          the resolution (1/4 stride) of the input backbone_features.
        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
          if `multimask_output=True` and M = 1 if `multimask_output=False`),
          upsampled from the low-resolution masks, with shape size as the image
          (stride is 1 pixel).
        - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
          if `multimask_output=False`), the estimated IoU of each output mask.
        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `low_res_multimasks`.
        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `high_res_multimasks`.
        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
          based on the output token from the SAM mask decoder.
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provide, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
        """
        Directly turn binary `mask_inputs` into a output mask logits without using SAM.
        (same input and output shapes as in _forward_sam_heads above).
        """
        # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
        out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
        mask_inputs_float = mask_inputs.float()
        high_res_masks = mask_inputs_float * out_scale + out_bias
        low_res_masks = F.interpolate(
            high_res_masks,
            size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
            align_corners=False,
            mode="bilinear",
            antialias=True,  # use antialias for downsampling
        )
        # a dummy IoU prediction of all 1's under mask input
        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
        if not self.use_obj_ptrs_in_encoder:
            # all zeros as a dummy object pointer (of shape [B, C])
            obj_ptr = torch.zeros(
                mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
            )
        else:
            # produce an object pointer using the SAM decoder from the mask input
            _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
                backbone_features=backbone_features,
                mask_inputs=self.mask_downsample(mask_inputs_float),
                high_res_features=high_res_features,
            )
        # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
        # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
        # on the object_scores from the SAM decoder.
        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
        is_obj_appearing = is_obj_appearing[..., None]
        lambda_is_obj_appearing = is_obj_appearing.float()
        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
        if self.pred_obj_scores:
            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_masks,
            high_res_masks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def forward_image(self, img_batch: torch.Tensor):
        """Get the image feature on the input batch."""
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
                backbone_out["backbone_fpn"][0]
            )
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
                backbone_out["backbone_fpn"][1]
            )
        return backbone_out

    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features."""
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def _prepare_memory_conditioned_features(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    ):
        """Fuse the current frame's visual feature map with previous memory."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        device = current_vision_feats[-1].device
        # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
        # In this case, we skip the fusion with any memory.
        if self.num_maskmem == 0:  # Disable memory and skip fusion
            pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
            return pix_feat

        num_obj_ptr_tokens = 0
        tpos_sign_mul = -1 if track_in_reverse else 1
        # Step 1: condition the visual features of the current frame on previous memories
        if not is_init_cond_frame:
            # Retrieve the memories encoded with the maskmem backbone
            to_cat_memory, to_cat_memory_pos_embed = [], []
            # Add conditioning frames's output first (all cond frames have t_pos=0 for
            # when getting temporal positional embedding below)
            assert len(output_dict["cond_frame_outputs"]) > 0
            # Select a maximum number of temporally closest cond frames for cross attention
            cond_outputs = output_dict["cond_frame_outputs"]
            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
                frame_idx, cond_outputs, self.max_cond_frames_in_attn
            )
            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
            # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
            # We also allow taking the memory frame non-consecutively (with stride>1), in which case
            # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
            stride = 1 if self.training else self.memory_temporal_stride_for_eval
            for t_pos in range(1, self.num_maskmem):
                t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                if t_rel == 1:
                    # for t_rel == 1, we take the last frame (regardless of r)
                    if not track_in_reverse:
                        # the frame immediately before this frame (i.e. frame_idx - 1)
                        prev_frame_idx = frame_idx - t_rel
                    else:
                        # the frame immediately after this frame (i.e. frame_idx + 1)
                        prev_frame_idx = frame_idx + t_rel
                else:
                    # for t_rel >= 2, we take the memory frame from every r-th frames
                    if not track_in_reverse:
                        # first find the nearest frame among every r-th frames before this frame
                        # for r=1, this would be (frame_idx - 2)
                        prev_frame_idx = ((frame_idx - 2) // stride) * stride
                        # then seek further among every r-th frames
                        prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
                    else:
                        # first find the nearest frame among every r-th frames after this frame
                        # for r=1, this would be (frame_idx + 2)
                        prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
                        # then seek further among every r-th frames
                        prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                if out is None:
                    # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                    # frames, we still attend to it as if it's a non-conditioning frame.
                    out = unselected_cond_outputs.get(prev_frame_idx, None)
                t_pos_and_prevs.append((t_pos, out))

            for t_pos, prev in t_pos_and_prevs:
                if prev is None:
                    continue  # skip padding frames
                # "maskmem_features" might have been offloaded to CPU in demo use cases,
                # so we load it back to GPU (it's a no-op if it's already on GPU).
                feats = prev["maskmem_features"].to(device, non_blocking=True)
                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                # Spatial positional encoding (it might have been offloaded to CPU in eval)
                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                # Temporal positional encoding
                maskmem_enc = (
                    maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
                )
                to_cat_memory_pos_embed.append(maskmem_enc)

            # Construct the list of past object pointers
            if self.use_obj_ptrs_in_encoder:
                max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
                # First add those object pointers from selected conditioning frames
                # (optionally, only include object pointers in the past during evaluation)
                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                    ptr_cond_outputs = {
                        t: out
                        for t, out in selected_cond_outputs.items()
                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                    }
                else:
                    ptr_cond_outputs = selected_cond_outputs
                pos_and_ptrs = [
                    # Temporal pos encoding contains how far away each pointer is from current frame
                    (
                        (
                            (frame_idx - t) * tpos_sign_mul
                            if self.use_signed_tpos_enc_to_obj_ptrs
                            else abs(frame_idx - t)
                        ),
                        out["obj_ptr"],
                    )
                    for t, out in ptr_cond_outputs.items()
                ]
                # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
                for t_diff in range(1, max_obj_ptrs_in_encoder):
                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                    if t < 0 or (num_frames is not None and t >= num_frames):
                        break
                    out = output_dict["non_cond_frame_outputs"].get(
                        t, unselected_cond_outputs.get(t, None)
                    )
                    if out is not None:
                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
                # If we have at least one object pointer, add them to the across attention
                if len(pos_and_ptrs) > 0:
                    pos_list, ptrs_list = zip(*pos_and_ptrs)
                    # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                    obj_ptrs = torch.stack(ptrs_list, dim=0)
                    # a temporal positional embedding based on how far each object pointer is from
                    # the current frame (sine embedding normalized by the max pointer num).
                    if self.add_tpos_enc_to_obj_ptrs:
                        t_diff_max = max_obj_ptrs_in_encoder - 1
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                        obj_pos = torch.tensor(pos_list).to(
                            device=device, non_blocking=True
                        )
                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                    else:
                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                    if self.mem_dim < C:
                        # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                        obj_ptrs = obj_ptrs.reshape(
                            -1, B, C // self.mem_dim, self.mem_dim
                        )
                        obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                    to_cat_memory.append(obj_ptrs)
                    to_cat_memory_pos_embed.append(obj_pos)
                    num_obj_ptr_tokens = obj_ptrs.shape[0]
                else:
                    num_obj_ptr_tokens = 0
        else:
            # for initial conditioning frames, encode them without using any previous memory
            if self.directly_add_no_mem_embed:
                # directly add no-mem embedding (instead of using the transformer encoder)
                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
                return pix_feat_with_mem

            # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

        # Step 2: Concatenate the memories and forward through the transformer encoder
        memory = torch.cat(to_cat_memory, dim=0)
        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

        pix_feat_with_mem = self.memory_attention(
            curr=current_vision_feats,
            curr_pos=current_vision_pos_embeds,
            memory=memory,
            memory_pos=memory_pos_embed,
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )
        # reshape the output (HW)BC => BCHW
        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
        return pix_feat_with_mem

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """Encode the current image and its prediction into a memory feature."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1).
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )
        # scale the raw mask logits with a temperature before applying sigmoid
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(
            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
        )
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (
                1 - is_obj_appearing[..., None, None]
            ) * self.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )

        return maskmem_features, maskmem_pos_enc

    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # fused the visual feature with previous memory features in the memory bank
            pix_feat = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )

        return current_out, sam_outputs, high_res_features, pix_feat

    def _encode_memory_in_output(
        self,
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                object_score_logits=object_score_logits,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
    ):
        current_out, sam_outputs, _, _ = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )

        (
            _,
            _,
            _,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        if not self.training:
            # Only add this in inference (to avoid unused param in activation checkpointing;
            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
            current_out["object_score_logits"] = object_score_logits

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )

        return current_out

    def _use_multimask(self, is_init_cond_frame, point_inputs):
        """Whether to use multimask output in the SAM head."""
        num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
        multimask_output = (
            self.multimask_output_in_sam
            and (is_init_cond_frame or self.multimask_output_for_tracking)
            and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
        )
        return multimask_output

    def _apply_non_overlapping_constraints(self, pred_masks):
        """
        Apply non-overlapping constraints to the object scores in pred_masks. Here we
        keep only the highest scoring object at each spatial location in pred_masks.
        """
        batch_size = pred_masks.size(0)
        if batch_size == 1:
            return pred_masks

        device = pred_masks.device
        # "max_obj_inds": object index of the object with the highest score at each location
        max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
        # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
        batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
        keep = max_obj_inds == batch_obj_inds
        # suppress overlapping regions' scores below -10.0 so that the foreground regions
        # don't overlap (here sigmoid(-10.0)=4.5398e-05)
        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks

<1> sigmoid函数的缩放因子和偏置

sigmoid_scale_for_mem_encsigmoid_bias_for_mem_enc 是用于调整Sigmoid 函数 的两个参数,分别控制 Sigmoid 函数的 缩放因子偏置。它们的作用是调整 Sigmoid 函数的形状和输出范围,从而影响模型的行为。

  • 缩放因子和偏置用于对 Sigmoid 函数的输入 ( x ) 进行缩放和平移,公式可以写成:

σ ( scale ⋅ x + bias ) = 1 1 + e − ( scale ⋅ x + bias ) \sigma(\text{scale} \cdot x + \text{bias}) = \frac{1}{1 + e^{-(\text{scale} \cdot x + \text{bias})}} σ(scalex+bias)=1+e(scalex+bias)1

  • 标准 Sigmoid 函数
    σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1
    这是标准的 Sigmoid 函数,输出范围是 ( 0, 1)。

  • 缩放因子的作用

    • 缩放因子 ( scale ) 用于调整 Sigmoid 函数的斜率。
    • 当 (scale > 1 ) 时,Sigmoid 函数的斜率变得更陡峭,输出对输入的变化更敏感。
    • 当 (scale < 1 ) 时,Sigmoid 函数的斜率变得更平缓,输出对输入的变化更不敏感。
    • 当 (scale = 1 ) 时,Sigmoid 函数保持不变。
  • 偏置的作用

    • 偏置 ( bias ) 用于对 Sigmoid 函数的输入进行平移。
    • 当 ( bias > 0 ) 时,Sigmoid 函数的输出整体向右平移。
    • 当 (bias < 0 ) 时,Sigmoid 函数的输出整体向左平移。
    • 当 ( bias = 0 ) 时,Sigmoid 函数保持不变。
  • 示例

    假设 ( scale = 2.0 ) 和 (bias = 0.5 ),则 Sigmoid 函数变为:

σ ( 2.0 ⋅ x + 0.5 ) = 1 1 + e − ( 2.0 ⋅ x + 0.5 ) \sigma(2.0 \cdot x + 0.5) = \frac{1}{1 + e^{-(2.0 \cdot x + 0.5)}} σ(2.0x+0.5)=1+e(2.0x+0.5)1

  • 对于不同的输入 ( x ),输出如下:

    • 当 ( x = 0 ) 时,输出约为 ( 0.62 )。

    • 当 ( x = 1 ) 时,输出约为 ( 0.92 )。

    • 当 ( x = -1 ) 时,输出约为 ( 0.18 )。

可以看出,缩放因子 ( 2.0 ) 使 Sigmoid 函数的斜率变得更陡峭,而偏置 ( 0.5 ) 使输出整体向右平移。

<2>memory bank的步长
memory_temporal_stride_for_eval=1,
  • 时间步长决定了 记忆库中存储哪些帧。具体来说:
    • 如果 memory_temporal_stride_for_eval=1,则每一帧都会被纳入记忆库。
    • 如果 memory_temporal_stride_for_eval=r(其中 r > 1),则每隔 r 帧才会有一帧被纳入记忆库。
<3>对象指针
  • 对象指针是一个 向量或嵌入(embedding),用于表示目标对象的位置、特征或状态
  • 对象指针的主要作用是:
    • 定位目标对象:帮助模型在复杂的场景中定位目标对象。
    • 跨帧跟踪目标:在视频序列中,帮助模型跨帧跟踪目标对象。
    • 提取目标特征:从目标对象中提取特征,用于后续处理。
  • 视频处理任务序列模型任务 中,模型通常需要从历史帧中提取信息(即记忆)来帮助处理当前帧。use_obj_ptrs_in_encoder 的作用是控制在编码器中,是否使用 对象指针 来跨帧提取信息:
    • 如果 use_obj_ptrs_in_encoder=True,则在编码器中使用对象指针,跨帧交叉注意力到其他帧的对象指针。
    • 如果 use_obj_ptrs_in_encoder=False,则在编码器中不使用对象指针,不跨帧交叉注意力到其他帧的对象指针。
<4>torch.nn.Module
  1. 初始化方法 (__init__)
    • 用于定义模块的子模块和参数。
    • 例如,可以在 __init__ 方法中定义卷积层、线性层等子模块。
  2. 前向传播方法 (forward)
    • 用于定义模块的前向传播逻辑。
    • 需要在子类中重写这个方法,以实现具体的前向传播计算。
  3. 参数管理
    • 提供了 parameters() 方法,用于返回模块的所有参数。
    • 提供了 named_parameters() 方法,用于返回模块的所有参数及其名称。
  4. 子模块管理
    • 提供了 children() 方法,用于返回模块的所有子模块。
    • 提供了 named_children() 方法,用于返回模块的所有子模块及其名称。
    • 提供了 modules() 方法,用于返回模块及其所有子模块。
    • 提供了 named_modules() 方法,用于返回模块及其所有子模块及其名称。
  5. 训练和评估模式
    • 提供了 train() 方法,用于将模块设置为训练模式。
    • 提供了 eval() 方法,用于将模块设置为评估模式。
  6. 保存和加载模型
    • 提供了 state_dict() 方法,用于返回模块的状态字典(包含所有参数和缓冲区)。
    • 提供了 load_state_dict() 方法,用于从状态字典加载模块的状态。
<5>ROI(region of interest)

1. ROI Pooling 的作用
ROI Pooling 的目的是将 不同大小的 ROI(感兴趣区域) 映射到 固定大小的特征图 上。它的输入是:

  • 一个特征图(通常是卷积神经网络的输出)。
  • 一组 ROI(感兴趣区域,用边界框表示)。

它的输出是:

  • 每个 ROI 对应的固定大小的特征图。

2. ROI Pooling 的步骤
ROI Pooling 的实现过程可以分为以下几步:

(1) 输入特征图和 ROI

  • 特征图:形状为 [C, H, W],其中:
    • C 是通道数。
    • HW 是特征图的高度和宽度。
  • ROI:用 (x1, y1, x2, y2) 表示,其中:
    • (x1, y1) 是 ROI 的左上角坐标。
    • (x2, y2) 是 ROI 的右下角坐标。

(2) 将 ROI 映射到特征图上

  • 由于特征图的分辨率通常低于原始图像,因此需要将 ROI 的坐标从 原始图像空间 映射到 特征图空间

  • 映射公式为:
    x feature = x image × spatial_scale x_{\text{feature}} = x_{\text{image}} \times \text{spatial\_scale} xfeature=ximage×spatial_scale

    y feature = y image × spatial_scale y_{\text{feature}} = y_{\text{image}} \times \text{spatial\_scale} yfeature=yimage×spatial_scale

    其中,spatial_scale 是特征图与原始图像的分辨率比例。

(3) 划分 ROI 为固定大小的网格

  • 将 ROI 划分为 pooled_height x pooled_width 的网格(例如 7x7)。

  • 每个网格的大小为:
    bin_size_h = h pooled_height \text{bin\_size\_h} = \frac{h}{\text{pooled\_height}} bin_size_h=pooled_heighth

    bin_size_w = w pooled_width \text{bin\_size\_w} = \frac{w}{\text{pooled\_width}} bin_size_w=pooled_widthw

    其中,hw 是 ROI 的高度和宽度。

(4) 对每个网格进行最大池化

  • 对每个网格内的特征值进行 最大池化(Max Pooling),得到固定大小的特征图。

3. 示例说明
假设:

  • 特征图的形状为 [1, 256, 64, 64]
  • ROI 的坐标为 [0.1, 0.2, 0.3, 0.4]
  • spatial_scale=1.0,即 ROI 的坐标已经映射到特征图空间。
  • output_size=(7, 7),即输出的特征图大小为 7x7

(1) 映射 ROI 到特征图

  • ROI 的坐标为 [0.1, 0.2, 0.3, 0.4],映射到特征图空间后为 [6.4, 12.8, 19.2, 25.6](假设特征图大小为 64x64)。

(2) 划分 ROI 为网格

  • ROI 的高度为 19.2 - 6.4 = 12.8,宽度为 25.6 - 12.8 = 12.8

  • 每个网格的大小为:
    bin_size_h = 12.8 7 ≈ 1.83 \text{bin\_size\_h} = \frac{12.8}{7} \approx 1.83 bin_size_h=712.81.83

    bin_size_w = 12.8 7 ≈ 1.83 \text{bin\_size\_w} = \frac{12.8}{7} \approx 1.83 bin_size_w=712.81.83

(3) 对每个网格进行最大池化

  • 对每个网格内的特征值进行最大池化,得到 7x7 的特征图。

4. 代码实现
以下是 ROI Pooling 实现对象指针:

import torch
import torch.nn as nn


# 目标检测结果 (假设形状为 [1, 4])
bbox = torch.tensor([[0.1, 0.2, 0.3, 0.4]])  # 目标对象的边界框 [x1, y1, x2, y2]

# 图像特征 (假设形状为 [1, 256, 64, 64])
image_features = torch.randn(1, 256, 64, 64)  # 1 个图像,特征维度为 256

# 生成对象指针
roi_pool = nn.ROIPool(output_size=(7, 7), spatial_scale=1.0)
obj_ptr = roi_pool(image_features, bbox)  # 提取目标对象的特征

print(obj_ptr.shape)  # 输出形状: [1, 256, 7, 7]
<6>hasattr
  • hasattr 是 Python 中的一个 内置函数,用于 检查对象是否具有指定的属性。它的作用是判断一个对象是否包含某个属性或方法,从而避免因访问不存在的属性而引发错误
hasattr(object, name)
  • object:要检查的对象。
  • name:要检查的属性或方法的名称(以字符串形式传递)。
  • 返回值:如果对象包含该属性或方法,返回 True;否则返回 False

三、代码运行

1.环境安装
  • 搭建虚拟环境

    • 选择python3.11版本
    • 在pytorch官网根据电脑配置安装对应版本
    • 在项目文件下输入终端命令
    pip install -e ".[notebooks]"
    
    • 下载过程较久
  • 在项目官网下载对应的checkpoints

2.运行脚本
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import cv2

# 使用 bfloat16 加速计算
torch.autocast(device_type="cuda", dtype=torch.float16).__enter__()

# 如果是 Ampere 架构的 GPU,启用 TF32
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

def apply_color_mask(image, mask, color, color_dark=0.5):
    """对掩体进行赋予颜色"""
    for c in range(3):
        image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c])
    return image

def show_anns(anns, borders=True):
    """可视化掩码"""
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask
        if borders:
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # 平滑轮廓
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(img)

# 指定输出目录
output_dir = "xxx"#图片输出目录
os.makedirs(output_dir, exist_ok=True)  # 如果目录不存在,则创建

# 加载图片
print("Loading image...")
image_path = 'xxx'#图片路径
if not os.path.exists(image_path):
    print(f"Error: Image file not found at {image_path}")
else:
    image = Image.open(image_path)
    image = np.array(image.convert("RGB"))
    print("Image loaded.")

# 加载模型
print("Loading model...")
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
sam2_checkpoint = "sam2.1_hiera_tiny.pt" #对应checkpoints路径
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
if not os.path.exists(sam2_checkpoint):
    print(f"Error: Model checkpoint not found at {sam2_checkpoint}")
if not os.path.exists(model_cfg):
    print(f"Error: Model config file not found at {model_cfg}")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device='cuda', apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2)
print("Model loaded.")

# 生成掩码
print("Generating masks...")
masks = mask_generator.generate(image)
if len(masks) == 0:
    print("Error: No masks generated. Check the image and model.")
else:
    print(f"Generated {len(masks)} masks.")

# 处理并保存每张掩码结果
image_select = image.copy()
for i in range(len(masks)):
    color = tuple(np.random.randint(0, 256, 3).tolist())  # 随机颜色
    selected_mask = masks[i]['segmentation']
    selected_image = apply_color_mask(image_select, selected_mask, color)
    output_path = os.path.join(output_dir, f"res_{i}.jpg")  # 指定输出路径
    cv2.imwrite(output_path, selected_image)
    print(f"Saved: {output_path}")

# 如果需要可视化掩码,可以取消注释以下代码
# plt.figure(figsize=(20, 20))
# plt.imshow(image)
# show_anns(masks)
# plt.axis('off')
# plt.show()
<think>好的,用户的问题是关于如何在SAM2代码中实现构建SAM的功能,需要教程。首先,我需要回顾用户提供的引用内容。引用[1]中有具体的代码示例,展示如何导入build_sam2函数,设置检查点和配置文件,然后构建模型和生成器。用户可能希望得到步骤化的指导,类似于示例中的代码结构。 用户的需求可能包括正确导入模块、配置参数、选择模型文件以及设置设备等步骤。需要确保用户正确安装相关库和环境,这可能涉及到安装SAM2的特定版本或其他依赖项。同时,用户提到的“build_sam”功能可能需要参考SAM2的官方文档或源码结构,确保参数正确。 可能遇到的问题包括路径错误,比如模型文件和配置文件的路径是否正确,以及设备设置是否正确(如CUDA是否可用)。此外,用户可能需要了解如何调整生成器的参数,如points_per_side等,以适应不同的任务需求。需要提醒用户注意这些参数的作用和调整方法。 另外,用户提供的引用[2]提到SAM2主要用于全自动分割,仅考虑单点输入。因此,在构建模型时,可能需要确保不引入交互式分割相关的参数,如掩码或边界框输入。需要强调配置文件的正确选择,如sam2.1_hiera_l.yaml,以及模型检查点的版本匹配。 总结回答的结构应该是步骤化的,从导入模块开始,到设置参数、构建模型、创建生成器,并提供示例用法。同时,添加注意事项,如文件路径、设备选择、参数调整和性能优化。最后,生成相关问题,帮助用户深入理解SAM2的使用和优化方法。</think>### SAM2构建模型教程 在SAM2代码中实现`build_sam`功能需要以下关键步骤: #### 1. 模块导入与依赖准备 首先需导入SAM2的核心模块: ```python from sam2.build_sam import build_sam2 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator ``` 需确保已安装SAM2库及相关依赖,推荐使用PyTorch 1.10+环境[^2]。 #### 2. 配置参数设置 ```python sam2_checkpoint = "sam2.1_hiera_large.pt" # 模型权重文件 model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" # 配置文件路径 device = "cuda" if torch.cuda.is_available() else "cpu" # 设备选择 ``` #### 3. 模型构建过程 ```python # 构建基础模型 sam2 = build_sam2( model_cfg=model_cfg, checkpoint=sam2_checkpoint, device=device, apply_postprocessing=False # 是否启用后处理 ) # 初始化自动掩码生成器 mask_generator = SAM2AutomaticMaskGenerator( model=sam2, points_per_side=32, # 采样点密度 pred_iou_thresh=0.88 # 预测置信度阈值 ) ``` #### 4. 功能调用示例 ```python masks = mask_generator.generate("input_image.jpg") # 输入图像路径或numpy数组 ``` ### 关键参数说明 | 参数 | 作用 | 推荐值 | |------|------|--------| | `points_per_side` | 控制采样点密度 | 32-64 | | `pred_iou_thresh` | 过滤低质量预测 | 0.85-0.95 | | `stability_score_thresh` | 稳定性评分阈值 | 0.92-0.98 | ### 注意事项 1. 文件路径需确保正确,建议使用绝对路径 2. 若使用CPU推理,需设置`device="cpu"` 3. 大型图像处理建议启用`apply_postprocessing=True`[^1] 4. 不同模型版本需对应匹配配置文件(如sam2.1需hiera_l.yaml)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值