论文笔记(八十四)Shelving, Stacking, Hanging: Relational Pose Diffusion for Multi-modal Rearrangement(三)

文章概括

引用:

@inproceedings{simeonov2023shelving,
  title={Shelving, Stacking, Hanging: Relational Pose Diffusion for Multi-modal Rearrangement},
  author={Simeonov, Anthony and Goyal, Ankit and Manuelli, Lucas and Lin, Yen-Chen and Sarmiento, Alina and Garcia, Alberto Rodriguez and Agrawal, Pulkit and Fox, Dieter},
  booktitle={Conference on Robot Learning},
  pages={2030--2069},
  year={2023},
  organization={PMLR}
}
Simeonov, A., Goyal, A., Manuelli, L., Lin, Y.C., Sarmiento, A., Garcia, A.R., Agrawal, P. and Fox, D., 2023, December. Shelving, Stacking, Hanging: Relational Pose Diffusion for Multi-modal Rearrangement. In Conference on Robot Learning (pp. 2030-2069). PMLR.

主页:https://anthonysimeonov.github.io/rpdiff-multi-modal/
原文: https://arxiv.org/abs/2307.04751
代码、数据和视频: https://github.com/anthonysimeonov/rpdiff


系列文章:
请在 《 《 文章 》 》 专栏中查找



宇宙声明!


引用解析部分属于自我理解补充,如有错误可以评论讨论然后改正!


摘要

我们提出了一个系统,用于重新排列场景中的物体,以实现所需的物体–场景放置关系,例如将书插入书架的一个开放槽中。 该流程能够推广到新的几何形状、姿态以及场景和物体的布局,并通过示范进行训练,直接在3D点云上运行。我们的系统克服了给定场景存在许多在几何上相似的重排解决方案所带来的挑战。通过利用迭代姿态去噪训练过程,我们能够拟合多模态示范数据并产生多模态输出,同时保持精确和准确。我们还展示了在条件于相关的局部几何特征、同时忽略那些损害泛化性和精确性的不相关全局结构时的优势。我们在仿真和现实世界中,通过三个不同的重排任务验证了我们的方法,这些任务都需要处理多模态并在物体形状与姿态上实现泛化。项目网站、代码和视频: https://anthonysimeonov.github.io/rpdiff-multi-modal

关键词: 物体重排、多模态、操作、点云


train_full.py 代码解析(目录:rpdiff/src/rpdiff/training/train_full.py)

if name == “main”:


if __name__ == "__main__":
  • 作用
    判断当前模块是否作为“主程序”被直接运行。

  • 为什么要这么写
    避免他人在别的脚本里 import train_full 时,不必要地执行后续初始化逻辑。

  • 具体示例

    • 直接在命令行执行 python train_full.py 时,__name__"__main__",条件成立,执行后续解析、训练逻辑。
    • 若在 other.py 中写 import train_full,此处不执行。

    """Parse input arguments"""
  • 作用
    这是一个独立的字符串字面量(docstring),因不在函数或模块顶层位置,对程序行为无影响,仅作注释。
  • 为什么要这么写
    标明下面几行是在“解析输入参数”,方便阅读。

    parser = argparse.ArgumentParser()
  • 作用
    创建一个 ArgumentParser 对象,用于定义和解析命令行参数。
  • 为什么要这么写
    Python 标准库中推荐的命令行参数解析方式,易于扩展和生成帮助信息。
  • 示例
    parser 此时为空,准备后续通过 add_argument 添加各种参数。

    parser.add_argument('-c', '--config_fname', type=str, required=True, help='Name of config file')
  • 作用
    定义一个必须参数 -c--config_fname,接受字符串,用于指定配置文件名。

    • 短写:-c 文件路径;是 “短选项(short flag)”
    • 长写:--config_fname 文件路径;是 “长选项(long flag)”
    • 两者指向同一个参数。当你在命令行中写 -c experiment1.yaml--config_fname experiment1.yaml,程序就会把 “experiment1.yaml” 赋值给 args.config_fname。
  • 为什么要这么写
    训练脚本往往依赖一个配置文件(如 YAML/JSON)来加载超参等。设为 required=True 强制用户提供。

  • 示例
    在命令行输入 -c experiment1.yaml,则解析后 args.config_fname == "experiment1.yaml"


    parser.add_argument('-d', '--debug', action='store_true', help='If True, run in debug mode')
  • 作用
    定义一个可选布尔开关 -d/--debug。若出现此标志,args.debug=True,否则为 False

    • 默认值:args.debug 会被初始化为 False
    • 如果在命令行出现 -d/--debug:则 args.debug 会被设成 True
  • 为什么要这么写
    方便在调试时打开更详细的日志、可视化或数据校验。

  • 示例

    • python train.py -c cfg.yaml -dargs.debug == True
    • python train.py -c cfg.yamlargs.debug == False

    parser.add_argument('-dd', '--debug_data', action='store_true', help='If True, run data loader in debug mode')
  • 作用
    同上,但专门控制数据加载部分的“调试模式”(例如只加载少量样本)。
  • 为什么要这么写
    数据加载可能是瓶颈,调试时只想快速跑过几条数据。
  • 示例
    --debug_data 打开后,args.debug_data == True,数据加载模块中可判定为“只取 16 条而非全部”。

    parser.add_argument('-p', '--port_vis', type=int, default=6000, help='Port for ZMQ url (meshcat visualization)')
  • 作用
    定义一个整数型可选参数 -p/--port_vis(缺省 6000),指定 MeshCat 可视化服务器的端口。

  • 为什么要这么写
    MeshCat 默认端口是 6000,但用户可以通过命令行改成 7000、8000 等,避免端口冲突。

  • 示例

    • 默认:args.port_vis == 6000
    • 用户指定 -p 7000args.port_vis == 7000

    parser.add_argument('-s', '--seed', type=int, default=0, help='Random seed')
  • 作用
    接收一个整数随机种子参数,缺省为 0。
  • 为什么要这么写
    实验可复现性需要固定随机数种子,且脚本运行时传不同种子可做多次实验对比。
  • 示例
    --seed 42args.seed == 42,训练和数据加载中所有 random.seed(42), np.random.seed(42)

    parser.add_argument('-l', '--local_dataset_dir', type=str, default=None,
                        help='If the data is saved on a local drive, pass the root of that directory here')
  • 作用
    可选地指定本地数据集所在根目录,若不传则使用配置文件中的默认路径。
  • 为什么要这么写
    当数据下载在网络存储或集群不同节点时,用户可显式指定本地缓存目录。
  • 示例
    --local_dataset_dir /mnt/data/datasetsargs.local_dataset_dir == "/mnt/data/datasets"

    parser.add_argument('-r', '--resume', action='store_true',
                        help='If set, resume experiment (required to be set in config as well)')
  • 作用
    布尔开关,若设置则脚本尝试从上一次中断的 checkpoint 恢复训练。
  • 为什么要这么写
    长周期训练易被打断,提供自动恢复能力。
  • 示例
    --resumeargs.resume == True,后续代码中会读取 train_args['resume'] 来决定加载 checkpoint。

    parser.add_argument('-m', '--meshcat', action='store_true',
                        help='If set, run with meshcat visualization (required to be set in config as well)')
  • 作用
    布尔开关,若打开则脚本内部启用 MeshCat 实时三维可视化。
  • 为什么要这么写
    可选择性地打开或关闭可视化,节省无视觉需求时的资源。
  • 示例
    --meshcatargs.meshcat == True

    args = parser.parse_args()
  • 作用
    解析命令行传入的所有参数,将结果存入 args(一个 Namespace 对象)。

  • 为什么要这么写
    至此,所有用户输入都被标准化到 args.xxxx 中,方便后续使用。

  • 示例
    若执行 python train_full.py -c cfg.yaml -d -p 7000,则

    args.config_fname == "cfg.yaml"
    args.debug       == True
    args.port_vis    == 7000
    args.seed        == 0       # 默认
    args.resume      == False   # 默认
    

    train_args = config_util.load_config(
        osp.join(path_util.get_train_config_dir(), args.config_fname)
    )
  • 作用

    1. path_util.get_train_config_dir() 返回项目中“训练配置文件”所在文件夹路径(如 "/home/user/project/configs/train")。
    2. osp.join(...) 拼出完整文件路径,例如 "/home/user/project/configs/train/exp1.yaml"
    3. config_util.load_config(path) 读取该 YAML/JSON 文件并解析成 Python 字典 dict
  • 为什么要这么写
    解耦文件夹位置与文件名,集中管理路径;同时统一从配置文件加载超参。

  • 具体示例

    • get_train_config_dir()"/project/configs/train"

    • args.config_fname = "demo.yaml"

    • 最终 load_config("/project/configs/train/demo.yaml")

      {
        "model": "resnet18",
        "lr": 0.001,
        "epochs": 100,
        # … 其他超参 …
      }
      

    train_args['debug'] = args.debug
  • 作用
    将命令行解析出的布尔 debug 标志写回配置字典中,覆盖配置文件中同名字段。
  • 为什么要这么写
    使得配置中的 debug 与命令行保持一致——命令行优先级更高。
  • 示例
    args.debug == True,则 train_args["debug"] = True

    train_args['debug_data'] = args.debug_data
    train_args['port_vis']   = args.port_vis
    train_args['seed']       = args.seed
    train_args['local_dataset_dir'] = args.local_dataset_dir
    train_args['meshcat_ap'] = args.meshcat
    train_args['resume_ap']  = args.resume
  • 作用
    同上,把所有命令行参数映射并覆盖到配置字典 train_args 中。
  • 为什么要这么写
    统一在一个 train_args 结构里管理所有运行时参数(无论是配置文件还是命令行传入)。
  • 示例
    如果用户未传 --local_dataset_dir,则 args.local_dataset_dir is Nonetrain_args['local_dataset_dir'] 也置为 None,此后代码会用配置文件或默认逻辑自动选择数据路径。

    train_args = config_util.recursive_attr_dict(train_args)
  • 作用
    把普通字典(dict)转换成“属性字典”/“AttrDict”,可以通过 train_args.debugtrain_args.lr 访问,而不仅限于 train_args['debug']

  • 为什么要这么写

    1. 阅读性好:args.seedargs['seed'] 简洁。
    2. 统一类型:函数参数 main 中可假定 train_args 支持属性访问。
  • 示例

    • 之前:train_args['epochs'] == 100
    • 之后:train_args.epochs == 100

    main(train_args)
  • 作用
    调用主流程函数 main,传入所有整理好的运行参数 train_args

  • 为什么要这么写
    把脚本的核心逻辑封装在 main 函数里,便于测试、复用和导入。

  • 输入

    • train_args:一个支持属性访问的参数对象,字段示例有

      train_args.config_fname          # str,如 "demo.yaml"
      train_args.debug                 # bool
      train_args.lr                    # float,学习率
      train_args.model                 # str,如 "resnet18"
      train_args.local_dataset_dir     # Optional[str]
      # … 等等 …
      

核心大函数/类汇总

名称作用输入输出/返回值常见接口/方法
path_util.get_train_config_dir()返回训练配置文件目录路径str,如 "/path/to/configs/train"可扩展:set_train_config_dir(path: str)
config_util.load_config(path)从 YAML/JSON 配置文件加载参数到 dictpath: strDict[str, Any]save_config(dict, path)
config_util.recursive_attr_dict(d)将普通 dict 递归地转为支持 .attr 访问的结构d: Dict[str, Any]AttrDictto_dict(), 支持嵌套
main(train_args)脚本主流程:数据加载、模型构建、训练、可视化、保存等train_args: AttrDictNoneTrainer, Evaluator, Visualizer 等类协作调用


def main(args: config_util.AttrDict):

  • 用途main 是脚本的核心入口,负责:

    1. 设置随机种子,保证可复现性;
    2. 搭建实验目录结构,记录本次运行的代码、配置、日志;
    3. 加载数据集,构建训练/验证数据管道;
    4. 初始化各子网络(粗糙可用性、位姿微调、成功分类),以及它们对应的损失函数和优化器;
    5. (可选)从 checkpoint 恢复训练状态
    6. 最后调用 train 函数真正执行训练流程。
  • 输入

    • args:一个 AttrDict 对象,属性来源于命令行参数和配置文件的合并,典型字段有:

      args.seed                  # 随机种子,例如 0 或 42
      args.experiment            # 实验相关子配置,包含 batch_size、num_iterations、logdir、resume_iter 等
      args.data                  # 数据相关配置,包含 data_root、dataset_path、voxel_grid、rot_grid_bins 等
      args.model                 # 三个子模型配置,包含 refine_pose、success、coarse_affordance 等
      args.loss                  # 对应三个子模型的损失设置
      args.optimizer             # 对应三个子模型的优化器设置
      args.meshcat_ap            # 是否打开 meshcat 可视化(命令行 `--meshcat`)
      args.port_vis              # meshcat 端口,例如 6000
      args.local_dataset_dir     # 本地数据根目录,例如 "/mnt/data/datasets"
      args.debug_data            # 数据加载调试开关
      
  • 输出

    • 无返回值,但会在 logdir 下生成模型、日志、可视化结果,并在屏幕/TensorBoard 打印训练信息。

1. 设置随机种子,保证可复现性

    random.seed(args.seed)
  • 作用:为 Python 标准库的 random 模块设置种子。
  • 示例:若 args.seed = 42,则 random.seed(42);后续调用 random.random() 会固定输出序列。
  • 为什么要这么写:确保使用 random 生成的随机数(如数据打乱顺序)在每次运行中一致。
    torch.manual_seed(args.seed)
  • 作用:为 PyTorch CPU 和单卡 GPU 设置全局随机种子。
  • 示例:同样用 42,后续所有 torch.randn()torch.randperm() 等操作可复现。
  • 为什么要这么写:深度学习训练中的权重初始化、数据增强等都依赖随机,需要固定。
    np.random.seed(args.seed)
  • 作用:为 NumPy 随机数生成器设置种子。
  • 示例np.random.seed(42) 后,np.random.rand(3) 每次均输出相同数组。
  • 为什么要这么写:如果数据预处理或其它步骤中使用了 NumPy 随机,需保持一致性。

2. 搭建实验目录与日志记录

    ##############################################
    # Setup basic experiment params
  • 注释:标志“搭建基础实验参数”这一逻辑块的开始。
    logdir = osp.join(
        path_util.get_rpdiff_model_weights(), 
        args.experiment.logdir, 
        args.experiment.experiment_name)
  • 作用:拼出主存储目录 logdir,用于保存模型权重、TensorBoard 日志等。

  • 示例

    • path_util.get_rpdiff_model_weights()"/home/user/rpdiff_weights"
    • args.experiment.logdir = "logs"
    • args.experiment.experiment_name = "exp1"
    • 最终 logdir = "/home/user/rpdiff_weights/logs/exp1"
  • 为什么要这么写:统一管理模型和日志,便于查找和对比不同实验。

    util.safe_makedirs(logdir)
  • 作用:若 logdir 不存在,递归创建目录;若已存在则忽略错误。
  • 示例"/home/user/rpdiff_weights/logs/exp1" 不存在时创建。
  • 为什么要这么写:避免后续写文件时报错,同时保留已有内容。
    # Set up experiment run/config logging
    nowstr = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
  • 作用:获取当前时间字符串,用于区分多次运行。

  • 示例:假设当前时间 2025-05-07 14:23:45,则

    nowstr == "05-07-2025_14-23-45"
    
  • 为什么要这么写:在同一 exp1 文件夹下,多次启动时用不同子文件夹保存。

    run_logs = osp.join(logdir, 'run_logs')
  • 作用:在主目录下拼出“运行日志”子目录路径。
  • 示例"/home/user/.../exp1/run_logs"
  • 为什么要这么写:将所有运行记录集中在 run_logs 里,清晰分层。
    util.safe_makedirs(run_logs)
  • 同上:递归创建 run_logs
    run_log_folder = osp.join(run_logs, nowstr)
  • 作用:拼出当前这次运行的唯一日志文件夹名。
  • 示例"/home/.../exp1/run_logs/05-07-2025_14-23-45"
  • 为什么要这么写:分次管理,避免文件冲突。
    util.safe_makedirs(run_log_folder)
  • 同上:创建本次运行日志目录。
    # copy everything we would like to know about this run in the run log folder
    for fn in os.listdir(os.getcwd()):
        if not (fn.endswith('.py') or fn.endswith('.sh') or fn.endswith('.bash')):
            continue
        log_fn = osp.join(run_log_folder, fn)
        shutil.copy(fn, log_fn)
  • 作用:把当前工作目录下所有 .py.sh.bash 脚本文件复制到运行日志目录,用于版本跟踪。
  • 示例:若项目根有 train.pyutils.pyrun.sh,则都拷贝过去。
  • 为什么要这么写:确保当时的代码快照被保留,方便日后复现或排查。
    full_cfg_dict = copy.deepcopy(config_util.recursive_dict(args))
  • 作用:将属性字典 args 转回普通 dict,并做深拷贝。

  • 示例

    full_cfg_dict = {
      "seed": 42,
      "experiment": {"logdir": "logs", "experiment_name": "exp1", ...},
      "model": {...}, ...
    }
    
  • 为什么要这么写:把合并后、最终运行时的所有配置保存为纯文本。

    full_cfg_fname = osp.join(run_log_folder, 'full_exp_cfg.txt')
  • 作用:拼出配置保存文件名。
  • 示例".../05-07-2025_14-23-45/full_exp_cfg.txt"
    json.dump(full_cfg_dict, open(full_cfg_fname, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
  • 作用:以 UTF-8、漂亮缩进格式将配置字典写入磁盘。

  • 示例:文件内容类似:

    {
        "seed": 42,
        "experiment": {
            "logdir": "logs",
            "experiment_name": "exp1",
            ...
        },
        ...
    }
    
  • 为什么要这么写:直观查看与脚本外部参数结合后的完整配置,增强透明度。


3. 可视化(MeshCat)初始化

    if args.experiment.meshcat_on and args.meshcat_ap:
  • 作用:检查配置与命令行是否同时打开了 meshcat 可视化。

  • 示例

    • args.experiment.yaml 中有 meshcat_on: true,且命令行传了 --meshcat,条件成立。
  • 为什么要这么写:双重开关确保用户有意图使用可视化,避免无意占用端口。

        zmq_url=f'tcp://127.0.0.1:{args.port_vis}'
  • 作用:按格式拼出 ZMQ 通信 URL。
  • 示例"tcp://127.0.0.1:6000"
        mc_vis = meshcat.Visualizer(zmq_url=zmq_url)
  • 作用:创建一个 MeshCat 可视化客户端。

  • 输入

    • zmq_url:ZMQ 监听地址;
  • 输出:一个 meshcat.Visualizer 对象,可通过 mc_vis['scene'] 操作 3D 场景。

  • 为什么要这么写:将过程中的点云、网格、坐标系等实时推送到浏览器查看。

  • 常见接口

    • mc_vis['scene'][<name>].set_object(...):添加模型;
    • mc_vis['scene'].delete():清空场景。
        mc_vis['scene'].delete()
  • 作用:删除场景中已有对象,保证每次启动时画布是空白的。
  • 为什么要这么写:避免重叠显示上次运行残留。
    else:
        mc_vis = None
  • 作用:若关闭可视化,则将 mc_vis 设为 None,下游代码判断后跳过可视化调用。
  • 为什么要这么写:统一接口,避免每次都判断两个标志。
    # prepare dictionary for extra kwargs in train function
    train_kwargs = {}
  • 作用:预先创建一个空字典,后续若有“第二模型”或自定义参数则填入,最后传给 train(...)
  • 为什么要这么写:保持 train 函数接口统一,方便可选输入。

4. 准备数据集与 DataLoader

    ##############################################
    # Prepare dataset and dataloader
  • 注释:开始数据加载模块。
    data_args = args.data
  • 作用:局部变量引用,减少后续多次 args.data 的冗余书写。
  • 示例data_args.data_root = "shape_data"data_args.dataset_path = "task1"
    if osp.exists(str(args.local_dataset_dir)):
  • 作用:判断用户是否通过命令行指定了一个真实存在的本地数据目录。
  • 示例:用户传 --local_dataset_dir /mnt/data/datasets,若该目录存在,则为真。
        dataset_path = osp.join(
            args.local_dataset_dir, 
            data_args.data_root,
            data_args.dataset_path)
  • 作用:若本地目录存在,则从本地拼出二级路径。
  • 示例"/mnt/data/datasets/shape_data/task1"
    else:
        dataset_path = osp.join(
            path_util.get_rpdiff_data(), 
            data_args.data_root,
            data_args.dataset_path)
  • 作用:否则从默认网络/集群路径加载数据。
  • 示例"/home/user/rpdiff_data/shape_data/task1"
    assert osp.exists(dataset_path), f'Dataset path: {dataset_path} does not exist'
  • 作用:若最终路径不存在,抛出错误并打印具体路径,帮助定位问题。
  • 示例:若路径错了,就会报 AssertionError: Dataset path: … does not exist
    train_dataset = dataio.FullRelationPointcloudPolicyDataset(
        dataset_path, 
        data_args,
        phase='train', 
        train_coarse_aff=args.experiment.train.train_coarse_aff,
        train_refine_pose=args.experiment.train.train_refine_pose,
        train_success=args.experiment.train.train_success,
        mc_vis=mc_vis, 
        debug_viz=args.debug_data)
  • 作用:实例化训练集对象。

  • 输入

    1. dataset_path: str,数据根目录;
    2. data_args: AttrDict,包含 voxel_gridrot_grid 等预处理参数;
    3. phase='train':区分 train/val,不同阶段做数据增强;
    4. 三个布尔开关:是否返回粗可用性、位姿微调、成功标记标签;
    5. mc_vis: meshcat.Visualizer or None:可视化句柄;
    6. debug_viz: bool:调试模式下每条数据只可视化少量样本。
  • 输出:一个 PyTorch Dataset 实例,支持 __len__(), __getitem__(idx) 接口。

  • 为什么要这么写:封装数据逻辑,统一返回三类任务所需输入。

    val_dataset = dataio.FullRelationPointcloudPolicyDataset(
        dataset_path, 
        data_args,
        phase='val', 
        train_coarse_aff=args.experiment.train.train_coarse_aff,
        train_refine_pose=args.experiment.train.train_refine_pose,
        train_success=args.experiment.train.train_success,
        mc_vis=mc_vis,
        debug_viz=args.debug_data)
  • 同上,但 phase='val',通常不做随机增强、shuffle。
    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=args.experiment.batch_size, 
        shuffle=True, 
        num_workers=args.experiment.num_train_workers, 
        drop_last=True)
  • 作用:用 PyTorch 自带 DataLoader 包装训练集:

    • batch_size: 如 32
    • shuffle=True: 每 epoch 打乱顺序
    • num_workers: 并行加载线程数,如 4
    • drop_last=True: 丢弃最后一个不满 batch
  • 输出:可迭代的 train_dataloader,每次返回一批数据字典。

    val_dataloader = DataLoader(
        val_dataset, 
        batch_size=2, 
        num_workers=1,
        shuffle=False, 
        drop_last=True)
  • 作用:验证集一般 batch 较小(这里写死 2),不开 shuffle,线程数少(1),drop_last 保持整批。
  • 示例:用于定期在训练过程中评估性能。
    # grab some things we need for training
    args.experiment.epochs = args.experiment.num_iterations / len(train_dataloader)
  • 作用:自动计算训练周期(epoch)数:

    epochs = 总迭代次数 每个 epoch 的 batch 数 \text{epochs} = \frac{\text{总迭代次数}}{\text{每个 epoch 的 batch 数}} epochs=每个 epoch  batch 总迭代次数

  • 示例:若 num_iterations = 10000len(train_dataloader)=312,则

    a r g s . e x p e r i m e n t . e p o c h s ≈ 10000 / 312 ≈ 32.05 args.experiment.epochs \approx 10000/312 \approx 32.05 args.experiment.epochs10000/31232.05

  • 为什么要这么写:方便后续按 epoch 进行学习率调度或日志记录。

    args.data.rot_grid_bins = train_dataset.rot_grid.shape[0]
  • 作用:将数据集实例内部生成的旋转网格维度输出回配置,用于后续模型初始化。
  • 示例:若 rot_grid(24, 3) 数组,则 rot_grid_bins = 24
  • 为什么要这么写:保持数据和模型对齐,避免手动同步出错。

5. 初始化三个子网络、损失与优化器

5.1 位姿微调模型(Pose Refinement)
    pr_type = args.model.refine_pose.type
  • 作用:读取配置中要用的微调模型类型字符串,如 "nsm_transformer"
  • 示例pr_type = "nsm_transformer"
    pr_args = config_util.copy_attr_dict(args.model[pr_type])
  • 作用:复制该模型类型对应的参数子字典,如 args.model['nsm_transformer'],用于初始化。
  • 示例pr_args = {"num_layers":4, "hidden_dim":128, ...}
    if args.model.refine_pose.get('model_kwargs') is not None:
        custom_pr_args = args.model.refine_pose.model_kwargs[pr_type]
        config_util.update_recursive(pr_args, custom_pr_args)
  • 作用:若在 args.model.refine_pose.model_kwargs 中有针对不同 pr_type 的自定义参数,则递归地合并覆盖。

  • 示例

    model:
      refine_pose:
        type: nsm_transformer
        model_kwargs:
          nsm_transformer:
            hidden_dim: 256
    

    上面就把 hidden_dim 从默认 128 更新为 256

  • 为什么要这么写:支持一次配置文件中同时保存“通用参数”与“某模型专属参数”。

    if pr_type == 'nsm_transformer':
        pr_model_cls = NSMTransformerSingleTransformationRegression
    elif pr_type == 'nsm_transformer_cvae':
        pr_model_cls = NSMTransformerSingleTransformationRegressionCVAE
    else:
        raise ValueError(f'Unrecognized: {pr_type}')
  • 作用:根据字符串选择具体的模型类。

  • 示例:若 "nsm_transformer",则

    pr_model_cls == NSMTransformerSingleTransformationRegression
    
  • 为什么要这么写:避免硬编码,仅靠 config 决定模型类型。

  • 常见接口(类构造函数):

    pr_model_cls(mc_vis, feat_dim, num_layers, hidden_dim, ...)
    

    返回一个 torch.nn.Module 子类实例。

    pr_model = pr_model_cls(
        mc_vis=mc_vis, 
        feat_dim=args.model.refine_pose.feat_dim, 
        **pr_args).cuda()
  • 作用:实例化微调模型并 .cuda() 转到 GPU。

  • 示例

    pr_model = NSMTransformerSingleTransformationRegression(
        mc_vis=None, feat_dim=64, num_layers=4, hidden_dim=128
    ).cuda()
    
  • 为什么要这么写:GPU 加速计算。

  • 输出pr_model,含 forward(...) 接口,接收点云 + 初始变换,输出细化后的位姿预测。

    pr_model_params = pr_model.parameters()
  • 作用:收集模型所有可训练参数,用于优化器。
  • 示例:生成一个 Python generator,遍历后是若干 torch.Tensor
    # loss
    pr_loss_type = args.loss.refine_pose.type
  • 作用:读取微调模型对应的损失函数类型字符串,如 "tf_chamfer"
    assert pr_loss_type in args.loss.refine_pose.valid_losses, \
           f'Loss type: {pr_loss_type} not in {args.loss.refine_pose.valid_losses}'
  • 作用:校验配置合法性,如果不在支持列表中则报错。
  • 示例:若 valid_losses = ['tf_chamfer', 'tf_chamfer_w_kldiv'],并且 pr_loss_type='mse',此处就会断言失败。
    if pr_loss_type == 'tf_chamfer':
        tfc_mqa_wrapper = losses.TransformChamferWrapper(
            l1=args.loss.tf_chamfer.l1,
            trans_offset=args.loss.tf_chamfer.trans_offset)
        pr_loss_fn = tfc_mqa_wrapper.tf_chamfer
  • 作用

    1. 实例化一个 TransformChamferWrapper,包装 Chamfer 距离损失。

      • l1: 是否加 L1 位置误差分量,布尔或系数,例如 True0.1
      • trans_offset: 平移偏移惩罚系数,如 0.01
    2. 从包装器中取出纯函数 tf_chamfer 作为损失函数。

  • 输出pr_loss_fn,接口签名类似

    loss = pr_loss_fn(pred_transforms, target_transforms)
    
  • 为什么要这么写:实现多种损失切换,并复用包装器中预处理逻辑。

    elif pr_loss_type == 'tf_chamfer_w_kldiv':
        tfc_mqa_wrapper = losses.TransformChamferWrapper(
            l1=args.loss.tf_chamfer.l1,
            trans_offset=args.loss.tf_chamfer.trans_offset,
            kl_div=True)
        pr_loss_fn = tfc_mqa_wrapper.tf_chamfer_w_kldiv
  • 作用:同上,但在 wrapper 中打开 kl_div 标志,生成带 KLDivergence 的损失函数。
    else:
        raise ValueError(f'Unrecognized: {pr_loss_fn}')
  • 防御性编程:若意外遇到未知类型,报错提醒配置错误。
    # optimizer
    pr_opt_type = args.optimizer.refine_pose.type
    assert pr_opt_type in args.optimizer.refine_pose.valid_opts, \
           f'Opt type: {pr_opt_type} not in {args.optimizer.refine_pose.valid_opt}'
  • 作用:类似损失,读取并校验优化器类型(如 "Adam""AdamW")。
    if pr_opt_type == 'Adam':
        pr_opt_cls = torch.optim.Adam 
    elif pr_opt_type == 'AdamW':
        pr_opt_cls = torch.optim.AdamW 
    else:
        raise ValueError(f'Unrecognized: {pr_opt_type}')
  • 作用:根据字符串映射到具体 PyTorch 优化器类。
    pr_opt_kwargs = config_util.copy_attr_dict(args.optimizer[pr_opt_type])
  • 作用:读取通用优化器超参,如 {"lr":0.001, "weight_decay":1e-4}
    if args.optimizer.refine_pose.get('opt_kwargs') is not None:
        custom_pr_opt_kwargs = args.optimizer.refine_pose.opt_kwargs[pr_opt_type]
        config_util.update_recursive(pr_opt_kwargs, custom_pr_opt_kwargs)
  • 作用:合并 refine_pose 专属的优化器参数覆盖通用设置。
    pr_optimizer = pr_opt_cls(pr_model_params, **pr_opt_kwargs)
  • 作用:生成微调模型的优化器实例。

  • 示例

    pr_optimizer = torch.optim.Adam(
        pr_model.parameters(), lr=0.001, weight_decay=1e-4
    )
    
  • 为什么要这么写:后续在训练循环中调用 pr_optimizer.zero_grad()pr_optimizer.step()


5.2 成功分类器(Success Classifier)

整个流程与位姿模型几乎一致,这里只列要点。

    sc_type = args.model.success.type
    sc_args = config_util.copy_attr_dict(args.model[sc_type])
    if args.model.success.get('model_kwargs') is not None:
        custom_sc_args = args.model.success.model_kwargs[sc_type]
        config_util.update_recursive(sc_args, custom_sc_args)
  • 作用:读取、复制并可能合并成功分类模型的参数。
	# model
    if sc_type == 'nsm_transformer':
        success_model_cls = NSMTransformerSingleSuccessClassifier
    else:
        raise ValueError(f'Unrecognized success model type: {sc_type}')
  • 接口NSMTransformerSingleSuccessClassifier(mc_vis, feat_dim, **sc_args) 返回 torch.nn.Moduleforward 输入点云与候选位姿,输出成功概率。
    success_model = success_model_cls(
        mc_vis=mc_vis,
        feat_dim=args.model.success.feat_dim,
        **sc_args).cuda()
    sc_model_params = success_model.parameters()
  • 示例feat_dim=64,其他参数如 num_heads=8
    # loss
    sc_loss_type = args.loss.success.type
    assert sc_loss_type in args.loss.success.valid_losses, f'Loss type: {sc_loss_type} not in {args.loss.success.valid_losses}'
    if sc_loss_type == 'bce_wo_logits':
        sc_loss_fn = losses.success_bce
  • 作用:若不带 logits 的二分类交叉熵,直接用封装的函数 success_bce(pred, label)
    elif sc_loss_type == 'bce_w_logits':
        double_batch = args.loss.bce_w_logits.double_batch_size
        batch_scalar = 2 if double_batch else 1 # in some experiments, we double the batch size for the success model
        bce_logits_wrapper = losses.BCEWithLogitsWrapper(
            pos_weight=args.loss.bce_w_logits.pos_weight,
            bs=args.experiment.batch_size*batch_scalar)
        sc_loss_fn = bce_logits_wrapper.success_bce_w_logits
    else:
        raise ValueError(f'Unrecognized: {sc_loss_type}')
  • 说明:带 logits 的版本,可指定 pos_weight、是否加倍 batch,对正负样本不均衡情况友好。
    # optimizer
    sc_opt_type = args.optimizer.success.type
    assert sc_opt_type in args.optimizer.success.valid_opts, ...
    if sc_opt_type == 'Adam':
        sc_opt_cls = torch.optim.Adam 
    elif sc_opt_type == 'AdamW':
        sc_opt_cls = torch.optim.AdamW 
    else:
        raise ValueError(f'Unrecognized: {sc_opt_type}')
    sc_opt_kwargs = config_util.copy_attr_dict(args.optimizer[sc_opt_type])
    if args.optimizer.success.get('opt_kwargs') is not None:
        custom_sc_opt_kwargs = args.optimizer.success.opt_kwargs[sc_opt_type]
        config_util.update_recursive(sc_opt_kwargs, custom_sc_opt_kwargs)
    sc_optimizer = sc_opt_cls(sc_model_params, **sc_opt_kwargs)
  • 作用:同位姿模型,初始化成功分类器的优化器。

5.3 粗糙可用性网络(Coarse Affordance)
	# model
    coarse_aff_type = args.model.coarse_aff.type
    coarse_aff_args = config_util.copy_attr_dict(args.model[coarse_aff_type])
    if args.model.coarse_aff.get('model_kwargs') is not None:
        custom_coarse_aff_args = args.model.coarse_aff.model_kwargs[coarse_aff_type]
        config_util.update_recursive(coarse_aff_args, custom_coarse_aff_args)
  • 作用:加载 coarse affordance 网络配置。
    if args.model.coarse_aff.multi_model:
        coarse_aff_args2 = config_util.copy_attr_dict(args.model[coarse_aff_type])
        if args.model.coarse_aff.get('model_kwargs2') is not None:
            custom_coarse_aff_args2 = args.model.coarse_aff.model_kwargs2[coarse_aff_type]
            config_util.update_recursive(coarse_aff_args2, custom_coarse_aff_args2)
  • 说明:若要启用“双模型”模式,则对第二个模型做同样配置,后面会并行训练。
    coarse_aff_model = CoarseAffordanceVoxelRot(
        mc_vis=mc_vis, 
        feat_dim=args.model.coarse_aff.feat_dim,
        rot_grid_dim=args.data.rot_grid_bins,
        padding=args.data.voxel_grid.padding,
        voxel_reso_grid=args.data.voxel_grid.reso_grid,
        euler_rot=args.model.coarse_aff.euler_rot,
        euler_bins_per_axis=args.model.coarse_aff.euler_bins_per_axis,
        scene_encoder_kwargs=coarse_aff_args).cuda()
  • 作用:初始化粗糙可用性网络并上 GPU。

  • 输入

    • feat_dim: 特征维度,例如 64
    • rot_grid_dim: 旋转采样点数,如 24
    • paddingvoxel_reso_grid: 体素网格大小/分辨率,例如 [32,32,32]
    • euler_roteuler_bins_per_axis: 是否启用欧拉旋转采样及每轴 bins 数
    • scene_encoder_kwargs: 上面合并后的网络超参字典
  • 输出torch.nn.Moduleforward(pointcloud) 返回每个体素的 affordance 概率。

    aff_model_params = coarse_aff_model.parameters()
  • 作用:收集参数。
	# loss
    aff_loss_type = args.loss.coarse_aff.type
    assert aff_loss_type in args.loss.coarse_aff.valid_losses, f'Loss type: {aff_loss_type} not in {args.loss.coarse_aff.valid_losses}'

    if aff_loss_type == 'voxel_affordance':
        aff_loss_fn = losses.voxel_affordance
    elif aff_loss_type == 'voxel_affordance_w_disc_rot':
        aff_loss_fn = losses.voxel_affordance_w_disc_rot
    elif aff_loss_type == 'voxel_affordance_w_disc_rot_euler':
        aff_loss_fn = losses.voxel_affordance_w_disc_rot_euler
    else:
        raise ValueError(f'Unrecognized: {aff_loss_type}')
  • 作用:根据类型切换三种可用性损失函数,分别是基础版本、带离散旋转惩罚、带欧拉旋转惩罚。
	# optimizer
    aff_opt_type = args.optimizer.coarse_aff.type
    assert aff_opt_type in args.optimizer.coarse_aff.valid_opts, f'Opt type: {aff_opt_type} not in {args.optimizer.coarse_aff.valid_opt}'

    if aff_opt_type == 'AdamW':
        aff_opt_cls = torch.optim.AdamW
    elif aff_opt_type == 'Adam':
        aff_opt_cls = torch.optim.Adam
    else:
        raise ValueError(f'Unrecognized: {aff_opt_type}')
    aff_opt_kwargs = config_util.copy_attr_dict(args.optimizer[aff_opt_type])
    if args.optimizer.coarse_aff.get('opt_kwargs') is not None:
        custom_aff_opt_kwargs = args.optimizer.coarse_aff.opt_kwargs[aff_opt_type]
        config_util.update_recursive(aff_opt_kwargs, custom_aff_opt_kwargs)
    aff_optimizer = aff_opt_cls(aff_model_params, **aff_opt_kwargs)
  • 作用:初始化 coarse affordance 的优化器。
	if args.model.coarse_aff.multi_model:
        coarse_aff_args2 = config_util.copy_attr_dict(args.model[coarse_aff_type])
        if args.model.coarse_aff.get('model_kwargs2') is not None:
            custom_coarse_aff_args2 = args.model.coarse_aff.model_kwargs2[coarse_aff_type]
            config_util.update_recursive(coarse_aff_args2, custom_coarse_aff_args2)

        voxel_reso_grid2 = args.data.voxel_grid.reso_grid
        voxel_reso_grid2 = util.set_if_not_none(voxel_reso_grid2, args.model.coarse_aff.model2.voxel_grid.reso_grid)

        padding2 = args.data.voxel_grid.padding
        padding2 = util.set_if_not_none(padding2, args.data.voxel_grid.padding)

        coarse_aff_model2 = CoarseAffordanceVoxelRot(
            mc_vis=mc_vis, 
            feat_dim=args.model.coarse_aff.feat_dim,
            rot_grid_dim=args.data.rot_grid_bins,
            padding=padding2,
            voxel_reso_grid=voxel_reso_grid2,
            scene_encoder_kwargs=coarse_aff_args2).cuda()

        aff_model_params2 = coarse_aff_model2.parameters()
        aff_loss_fn2 = aff_loss_fn
        aff_optimizer2 = aff_opt_cls(aff_model_params2, **aff_opt_kwargs)

        train_kwargs['coarse_aff_model2'] = {}
        train_kwargs['coarse_aff_model2']['model'] = coarse_aff_model2
        train_kwargs['coarse_aff_model2']['opt'] = aff_optimizer2
        train_kwargs['coarse_aff_model2']['loss_fn'] = aff_loss_fn2
        train_kwargs['coarse_aff_model2']['reso_grid'] = voxel_reso_grid2
        train_kwargs['coarse_aff_model2']['padding'] = padding2
        
    # model_dict = dict(model=model, rot=rot_model)
    # if args.debug:
    #     print('Coarse affordance model: ')
    #     print(coarse_aff_model)
    #     print('Refine pose model: ')
    #     print(pr_model)
    #     print('Success model: ')
    #     print(success_model)
  • 作用:若双模型模式,将第二套模型、优化器、损失、分辨率等都放入 train_kwargs,供后续 train 函数使用。

6. 从 Checkpoint 恢复(可选)

    ##############################################
    # Load checkpoints if resuming
  • 注释:开始恢复逻辑。
    if args.experiment.resume and args.resume_ap:
  • 作用:仅在配置文件和命令行都要求 resume=True 时才执行。
		# find the latest iteration
        ckpts = [int(val.split('model_')[1].replace('.pth', ''))
	                 for val in os.listdir(logdir)
	                 if (val.endswith('.pth') and 'latest' not in val)]
  • 作用:扫描 logdir 下所有命名为 model_XXX.pth 文件,取出数值部分作为可恢复的迭代号列表。
  • 示例:文件夹里有 model_1000.pth, model_2000.pth, 则 ckpts = [1000, 2000]
        args.experiment.resume_iter = max(ckpts)
  • 作用:选择最大的迭代号作为恢复点。
  • 示例resume_iter = 2000
    if args.experiment.resume_iter != 0:
  • 作用:若恢复迭代号不为 0,则真正加载对应 checkpoint。默认 resume_iter=0 表示从头开始。
        print(f'Resuming at iteration: {args.experiment.resume_iter}')
  • 作用:在屏幕打印恢复点,增强可见性。
        model_path = osp.join(logdir, f'model_{args.experiment.resume_iter}.pth')
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
  • 作用

    1. 拼出 checkpoint 路径,如 .../model_2000.pth
    2. 用 CPU 加载(先到 CPU 再 .cuda() 或直接在 GPU 上加载也可)。
  • 输出checkpoint 是一个 dict,典型键包括:

    {
      'refine_pose_model_state_dict': {...},
      'pr_optimizer_state_dict': {...},
      'success_model_state_dict': {...},
      'sc_optimizer_state_dict': {...},
      'coarse_aff_model_state_dict': {...},
      'aff_optimizer_state_dict': {...},
      # 如双模型还会有 _state_dict2
    }
    
        if args.experiment.train.train_refine_pose:
            pr_model.load_state_dict(checkpoint['refine_pose_model_state_dict'])
            pr_optimizer.load_state_dict(checkpoint['pr_optimizer_state_dict'])
  • 作用:若该任务开启,则分别恢复模型权重和优化器状态(包括动量、学习率调度信息等)。
        if args.experiment.train.train_success:
            success_model.load_state_dict(checkpoint['success_model_state_dict'])
            sc_optimizer.load_state_dict(checkpoint['sc_optimizer_state_dict'])
  • 同上:成功分类器恢复。
        if args.experiment.train.train_coarse_aff:
            coarse_aff_model.load_state_dict(checkpoint['coarse_aff_model_state_dict'])
            aff_optimizer.load_state_dict(checkpoint['aff_optimizer_state_dict'])
            if args.model.coarse_aff.multi_model:
                coarse_aff_model2.load_state_dict(checkpoint['coarse_aff_model_state_dict2'])
                aff_optimizer.load_state_dict(checkpoint['aff_optimizer_state_dict2'])
  • 同上:粗糙可用性网络(以及第二模型)恢复。

7. TensorBoard 日志与设备选择

    logger = SummaryWriter(logdir)
  • 作用:创建 TensorBoard 日志写入器,所有 logger.add_scalar(...) 会保存在该目录下。

  • 输入logdir,例如 "/home/.../exp1"

  • 输出logger 对象,常见接口:

    logger.add_scalar('loss/train', loss_value, global_step)
    logger.add_image('viz/pointcloud', img_tensor, global_step)
    
    it = args.experiment.resume_iter
  • 作用:把恢复的迭代号赋给 it,作为训练循环的起始步数。
  • 示例:若恢复到 2000,则从 it=2000 开始继续计数。
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        raise ValueError('Cuda not available')
  • 作用

    1. 检查是否有 GPU;
    2. 若有,则将 device 设为第一张卡 cuda:0
    3. 否则报错。
  • 为什么要这么写:脚本依赖 GPU 训练,若无 GPU 需立刻停止。


8. 配置合法性检查

    ##############################################
    # Perform other checks to ensure experiment config is valid
  • 注释:开始一些互斥/依赖参数的断言检查。
    train_exp_args = config_util.copy_attr_dict(args.experiment.train)
  • 作用:复制一份训练子配置到局部变量中,简化书写。
    if train_exp_args.refine_pose_from_coarse_pred:
        assert train_exp_args.train_coarse_aff and train_exp_args.train_refine_pose, \
               'Must be training both coarse and refine to use predictions as refinement input'
  • 作用:若想用粗糙网络的输出作为微调网络输入,则必须同时训练两者。
    assert not (train_exp_args.success_from_coarse_pred and train_exp_args.success_from_refine_pred), \
           'Cannot predict success from both refine pred and coarse pred. Please only set one of these to True'
  • 作用:不能同时使用两种预测结果来判断成功,否则逻辑冲突。
    if train_exp_args.success_from_coarse_pred: 
        assert train_exp_args.train_coarse_aff and train_exp_args.train_success, \
               'Must be training both coarse and success to use predictions as success input'
  • 作用:依赖检查。
    if train_exp_args.success_from_refine_pred: 
        assert train_exp_args.train_refine_pose and train_exp_args.train_success, \
               'Must be training both refine and success to use predictions as success input'
  • 作用:同上。

9. 启动训练

    train(
        mc_vis, 
        coarse_aff_model, pr_model, success_model, 
        aff_optimizer, pr_optimizer, sc_optimizer, 
        train_dataloader, val_dataloader, 
        aff_loss_fn, pr_loss_fn, sc_loss_fn,
        device, 
        logger, 
        logdir, 
        args,
        it, 
        **train_kwargs)
  • 作用:调用训练主循环函数 train,传入所有模型、优化器、数据管道、损失函数、设备、日志、初始迭代号,以及可选的第二模型等额外参数。

  • 输入

    1. mc_vis: MeshCat 可视化器或 None
    2. coarse_aff_model, pr_model, success_model: 三个 torch.nn.Module
    3. aff_optimizer, pr_optimizer, sc_optimizer: 三个优化器
    4. train_dataloader, val_dataloader
    5. aff_loss_fn, pr_loss_fn, sc_loss_fn: 三个损失函数
    6. device: torch.device('cuda:0')
    7. logger: SummaryWriter
    8. logdir: 主目录字符串
    9. args: 完整配置
    10. it: 起始迭代号
    11. **train_kwargs: 如多模型时的第二模型字典
  • 输出:无返回值,主要在内部执行:

    • for epoch in range(...)

      • for batch in train_dataloader

        • 依次调用三段网络前向、计算损失、反向传播、优化器 step()
        • 根据配置在每隔 N 步或每个 epoch 执行验证、可视化、保存 checkpoint。
  • 为什么要这么写:将所有准备好的资源一并传入训练函数,职责分离——main 只做准备,train 专注运行。



def train( … ):

def train(
        mc_vis: meshcat.Visualizer, 
        coarse_aff_model: nn.Module, refine_pose_model: nn.Module, success_model: nn.Module, 
        aff_optimizer: Optimizer, pr_optimizer: Optimizer, sc_optimizer: Optimizer,
        train_dataloader: DataLoader, test_dataloader: DataLoader, 
        aff_loss_fn: Callable, pr_loss_fn: Callable, sc_loss_fn: Callable,
        dev: torch.device, 
        logger: SummaryWriter, 
        logdir: str, 
        args: config_util.AttrDict,
        start_iter: int=0,
        **kwargs):
  • 作用:定义训练主循环函数,接收可视化句柄、三大模型及其优化器、数据加载管道、三种损失函数、设备、日志写入器、输出目录、配置参数、起始迭代数,以及可选的额外模型(kwargs)。

  • 设计原因:职责分明——main 函数准备好所有资源后,一次性传给 traintrain 专注于迭代流程。

  • 接口

    • mc_vis:或 None,用于实时 3D 可视化;
    • coarse_aff_model/refine_pose_model/success_model:继承自 nn.Module,需实现 forward(...)
    • *_optimizer:标准 torch.optim.Optimizer,需实现 .zero_grad(), .step(), .load_state_dict() 等;
    • train_dataloader/test_dataloadertorch.utils.data.DataLoader,迭代返回 (coarse_aff_sample, refine_pose_sample, success_sample)
    • *_loss_fn:调用签名如 (pred, gt)→loss_tensor
    • devtorch.device('cuda:0')
    • loggertensorboard.SummaryWriter
    • logdir:字符串,如 "/home/.../exp1"
    • argsAttrDict,包含 args.experiment.batch_sizeargs.experiment.num_iterations 等;
    • start_iter:整型,恢复训练时的初始步数,默认为 0;
    • **kwargs:如双模型时包含 coarse_aff_model2 等。

1. 前期的配置

    coarse_aff_model.train()
    refine_pose_model.train()
    success_model.train()
  • 每行作用:将 PyTorch 模型切换到“训练模式”(train()),启用 BatchNorm 的更新、Dropout 生效等。
  • 示例:若模型中含 nn.Dropout(p=0.5),此时在 .train() 下,每个前向调用随机丢弃约 50% 神经元;在 .eval() 则不丢弃。
  • 为什么这么写:保证后续训练时的正则化层行为正确;如果忘记,会导致验证时与训练时分歧巨大。

    offset = np.array([0, 0.2, 0.0])
  • 作用:定义一个三维偏移向量,用于可视化或数据预处理时在 z 轴方向平移(例如把物体稍微抬高)。
  • 示例offset = [0.0, 0.2, 0.0],表示在 y 方向上抬高 0.2 米。
  • 为什么这么写:便于在 MeshCat 中把点云稍微抬高,避免与地面平面重叠导致难以辨识。

    bs = args.experiment.batch_size
  • 作用:局部变量保存 batch size,减少后面多次访问 args 的开销。
  • 示例:若配置中 batch_size=32,则 bs = 32
  • 为什么这么写:简洁易用,也清晰地将 batch size 引入局部作用域。

    it = start_iter
  • 作用:将函数参数 start_iter(恢复时的迭代步数)赋给局部变量 it,用作全局迭代计数器。
  • 示例:若恢复训练,start_iter=2000,则从第 2000 步开始;否则 start_iter=0
  • 为什么这么写:在训练循环里每处理一个 batch 都会 it += 1

    voxel_grid_pts = torch.from_numpy(train_dataloader.dataset.raster_pts).float().cuda()
  • 作用

    1. 从训练集 Dataset 对象中读取预先生成的体素格点坐标数组 raster_pts,形状为 (N,3),例如 (32768,3)
    2. 转为 torch.Tensor(dtype=float32)
    3. .cuda() 放到 GPU。
  • 示例:若 raster_pts(32³,3) 个网格点坐标,内容如 [[ -0.2, -0.2, -0.2 ], [ -0.2, -0.2, -0.192 ], …]

  • 为什么这么写:后续在前向或损失计算中需要这些网格点,可避免每个 batch 重复拷贝。


    rot_mat_grid = torch.from_numpy(train_dataloader.dataset.rot_grid).float().cuda()
  • 作用:同上,但读取旋转矩阵列表 rot_grid,形状 (R,3,3),例如 (24,3,3),然后上 GPU。
  • 示例rot_grid[0] 可能是绕 x 轴旋转 0°,rot_grid[1] 是绕 x 轴旋转 15°,等等。
  • 为什么这么写:后续生成不同旋转角度下的点云候选,需要快速批量矩阵相乘。

    args.experiment.dataset_length = len(train_dataloader.dataset)
  • 作用:将数据集大小写回配置,供日志或学习率调度使用。
  • 示例:若训练集有 10 000 条样本,则 dataset_length = 10000
  • 为什么这么写:避免手动再去测 len(dataset),集中在 args 中管理。

    if 'coarse_aff_model2' in kwargs:
  • 作用:检测是否启用了“双模型”训练模式(第二套 coarse affordance 模型),如果有,则初始化相关变量。
  • 示例:当 kwargs={'coarse_aff_model2':{...}} 时进入。

        coarse_aff_model2 = kwargs['coarse_aff_model2']['model']
        aff_optimizer2   = kwargs['coarse_aff_model2']['opt']
        aff_loss_fn2     = kwargs['coarse_aff_model2']['loss_fn']
        reso_grid2       = kwargs['coarse_aff_model2']['reso_grid']
        padding2         = kwargs['coarse_aff_model2']['padding']
  • 每行作用:分别从 kwargs['coarse_aff_model2'] 字典中取出第二模型实例、优化器、损失函数、体素分辨率和 padding。

  • 示例

    • reso_grid2 = 32(即 32³ 网格);
    • padding2 = 0.1(米)。
  • 为什么这么写:方便后续使用第二套模型的相同训练逻辑。


        voxel_grid_pts2 = three_util.get_raster_points(reso_grid2, padding=padding2)
  • 作用:调用工具函数 three_util.get_raster_points(n, padding),生成 (n³,3) 格点在世界坐标系下的坐标。

  • 输入

    • reso_grid2=32
    • padding2=0.1
  • 输出numpy.ndarray,形状 (32768,3)

  • 为什么这么写:第二模型可能有与第一模型不同的网格参数,需要单独生成。


        # reshape to grid, and swap axes (permute x and z), B x reso x reso x reso x 3
        voxel_grid_pts2 = voxel_grid_pts2.reshape(reso_grid2, reso_grid2, reso_grid2, 3)
        voxel_grid_pts2 = voxel_grid_pts2.transpose(2, 1, 0, 3)
  • 作用

    1. (N,3) 重塑为 (reso,reso,reso,3) 的立方体索引形式;
    2. transpose(2,1,0,3) 将原来的 (x,y,z) 维度调换成 (z,y,x),以符合后续函数的排列习惯。
  • 示例:若 reso_grid2=4voxel_grid_pts2.reshape(4,4,4,3)

  • 为什么这么写:某些下游算法(如体素卷积)需要特定维度顺序。


        # reshape back to B x N x 3
        voxel_grid_pts2 = torch.from_numpy(voxel_grid_pts2.reshape(-1, 3)).float().cuda()
  • 作用:把数据恢复成 (N,3) 格式并上 GPU,与 voxel_grid_pts 一致。
  • 示例:恢复成 (32768,3)
  • 为什么这么写:按 Tensor 形式并行运算更高效。

        coarse_aff_model2 = coarse_aff_model2.train()
  • 作用:同第一模型,将第二模型切换到训练模式,同时把返回值(通常是自身)赋回。
  • 为什么这么写:确保两套模型都在 train() 模式下。

2. 进入主训练循环

    while True:
  • 作用:开启一个无限循环,内部以迭代次数 it 控制退出。
  • 为什么这么写:可灵活在中途 break,并在循环内部集中做各种训练步骤。

        if it > args.experiment.num_iterations:
            break
  • 作用:当全局迭代 it 超过配置的最大迭代数 num_iterations(如 100000)时,退出循环。
  • 示例:若 num_iterations=50000it=50001,则结束训练。
  • 为什么这么写:控制训练总步数。

2.1 可选数据调试
        if args.debug_data:
            # sample = train_dataloader.dataset[0]
            sample = train_dataloader.dataset[371]
            # sample = train_dataloader.dataset[1963]
            print('[Debug Data] Here with sample')

            # for i in range(len(train_dataloader.dataset)):
            #     sample = train_dataloader.dataset[i]
            #     if 'parent_start_pcd' not in sample[1][0].keys():
            #         print(f'[Debug Data] Here with bad sample (index: {i})')
            #         from IPython import embed; embed()

            from IPython import embed; embed()
  • 作用:当开启 --debug_data 时,从数据集中取固定索引(371)样本,打印调试信息并进入 IPython 交互。
  • 示例:这样可以在训练前检查某个样本是否有缺失字段,或可视化它。
  • 为什么这么写:快速定位哪条数据有问题,避免整个训练流程卡在某个 batch 上。

        for sample in train_dataloader:
  • 作用:遍历 DataLoader 返回的每个 batch,sample 通常是三元组 (coarse_aff_sample, refine_pose_sample, success_sample)
  • 示例:若 batch size=32,则每次 sample 中每个子项的数据量为 32 条。
  • 为什么这么写:标准 PyTorch 训练循环写法。

            it += 1
  • 作用:迭代步数自增 1,用于日志、学习率调度、Checkpoint 保存等。
  • 示例:第一次循环时 it 从 0→1,第二次 1→2。
  • 为什么这么写:准确记录训练进度。
            current_epoch = it * bs / len(train_dataloader.dataset)
  • 作用:动态计算当前“epoch”进度:已训练样本数 it*bs 除以总样本数。
  • 示例:如果 it=156bs=32dataset_length=10000,则 current_epoch = 156*32/10000 ≈ 0.4992,即接近第 0.5 个 epoch。
  • 为什么这么写:便于在 TensorBoard 中横坐标以 epoch 为单位显示。
            start_time = time.time()
  • 作用:记录本次 batch 前的时间戳,用于计算本次前向+反向用时。
  • 示例start_time = 1700000000.123(Unix 时间)。
  • 为什么这么写:性能监控,判断每个 batch 速度。

            coarse_aff_sample, refine_pose_sample, success_sample = sample
  • 作用:解包 sample,分别得到三组任务数据。

  • 示例

    • coarse_aff_sample = (inputs_dict, gt_dict)
    • refine_pose_sample = (inputs_dict, gt_dict)
    • success_sample = (inputs_dict, gt_dict)
  • 为什么这么写:三任务并行,可以选择性地训练。

            coarse_aff_mi, coarse_aff_gt = coarse_aff_sample
            refine_pose_mi, refine_pose_gt = refine_pose_sample
            success_mi, success_gt       = success_sample
  • 作用:进一步解包——mi 表示模型输入字典(model input),gt 表示 ground-truth 字典。

  • 示例

    coarse_aff_mi = {'pointcloud':Tensor(32,1024,3), 'voxel_centers':Tensor(32768,3), …}
    coarse_aff_gt = {'aff_labels':Tensor(32,32768), …}
    
  • 为什么这么写:保持统一命名,后续可写 dict_to_gpu(coarse_aff_mi)

            coarse_aff_out = None
            refine_pose_out = None
            success_out = None
  • 作用:先把输出占位为 None,以便后面条件分支后检查是否生成。
  • 为什么这么写:清晰标识三种输出可能未被某些分支产生。
            loss_dict = {}
  • 作用:初始化一个空字典,后续把各任务的损失值收集进来,用于一次性日志打印。
  • 为什么这么写:统一管理不同子任务的损失输出。

2.2 训练粗可用性子任务
            if args.experiment.train.train_coarse_aff and (len(coarse_aff_mi) > 0):
  • 作用:仅当配置开启该子任务,且输入非空时才执行该块。
  • 示例args.experiment.train.train_coarse_aff=True,且 coarse_aff_mi 包含若干点云,则进入。
                # prepare input and gt
                coarse_aff_mi = dict_to_gpu(coarse_aff_mi)
                coarse_aff_gt = dict_to_gpu(coarse_aff_gt)
  • 作用:调用工具函数 dict_to_gpu,将字典中所有 Tensor 从 CPU 移到 GPU。

  • 接口

    def dict_to_gpu(d: Dict[str,Tensor]) -> Dict[str,Tensor]:
        return {k:v.cuda() for k,v in d.items()}
    
  • 为什么这么写:模型在 GPU 上运算,输入数据必须位于同一设备。

                coarse_aff_out = train_iter_coarse_aff(
                    coarse_aff_mi,
                    coarse_aff_gt,
                    coarse_aff_model,
                    aff_optimizer,
                    aff_loss_fn,
                    args,
                    voxel_grid_pts, args.data.voxel_grid.reso_grid,
                    rot_mat_grid, args.data.rot_grid_bins,
                    it, current_epoch,
                    logger, 
                    mc_vis=mc_vis)
  • 作用:调用“单次训练迭代”函数 train_iter_coarse_aff,完成一次前向、损失、反向、优化并可视化。

  • 接口

    def train_iter_coarse_aff(
        inputs:Dict, gt:Dict,
        model:nn.Module,
        optimizer:Optimizer,
        loss_fn:Callable,
        args:AttrDict,
        voxel_pts:Tensor, reso:int,
        rot_mats:Tensor, rot_bins:int,
        iter_num:int, epoch:float,
        logger:SummaryWriter,
        mc_vis:meshcat.Visualizer=None,
        refine_pred:bool=False
    ) -> Dict[str,Any]:
    
    • 输入:如上;

    • 输出coarse_aff_out 是字典,包含

      {
        'model_output': Tensor,   # 原始网络预测
        'loss': {'aff_loss': Tensor(...), ...},
        'viz_data': {...}         # 可视化所需
      }
      
  • 为什么这么写:把核心训练逻辑封装,train 保持清晰。

                # process output for logging
                for k, v in coarse_aff_out['loss'].items():
                    loss_dict[k] = v
  • 作用:遍历该子任务返回的所有损失项(可能有主损失 + 辅助损失),收集到 loss_dict,后续统一打印/写入 TensorBoard。
  • 示例:可能有 {'aff_loss': Tensor(0.123), 'rot_loss': Tensor(0.045)}

2.2.1 结合粗预测生成 refine 输入
                if args.experiment.train.coarse_aff_from_coarse_pred:
                	# process coarse prediction and refine input to create new refine input
  • 作用:若配置要求用本次粗预测结果再做一次粗预测,则进入。
                    if args.model.coarse_aff.multi_model:
                        aff_refine_model      = coarse_aff_model2
                        aff_refine_opt        = aff_optimizer2
                        aff_refine_loss_fn    = aff_loss_fn2
                        aff_refine_voxel_grid_pts = voxel_grid_pts2
                        aff_refine_reso_grid  = reso_grid2
                    else:
                        aff_refine_model      = coarse_aff_model
                        aff_refine_opt        = aff_optimizer
                        aff_refine_loss_fn    = aff_loss_fn
                        aff_refine_voxel_grid_pts = voxel_grid_pts
                        aff_refine_reso_grid  = args.data.voxel_grid.reso_grid
  • 作用:根据是否启用双模型,选择做 refine 的那一套模型、参数和网格。
                    coarse_aff_mi, coarse_aff_gt = coarse_aff_to_coarse_aff(
                        coarse_aff_mi, coarse_aff_out['model_output'], coarse_aff_gt,
                        rot_mat_grid, voxel_grid_pts, aff_refine_voxel_grid_pts, aff_refine_reso_grid, args, mc_vis=mc_vis)
  • 作用:调用转换函数coarse_aff_to_coarse_aff,将粗预测结果 model_output 结合原始输入/标签,生成新的输入/标签对,用于二次迭代。

  • 接口

    def coarse_aff_to_coarse_aff(
        mi:Dict, pred:Tensor, gt:Dict,
        rot_mats:Tensor, voxel_pts1:Tensor,
        voxel_pts2:Tensor, reso2:int,
        args:AttrDict, mc_vis:Visualizer=None
    ) -> Tuple[Dict,Dict]:
        # 1. 根据 pred 生成新的体素网格预测标签
        # 2. 构建新的 mi/gt 字典
    
  • 为什么这么写:尝试多次迭代 refine,提升可用性预测精度。

                    coarse_aff_mi = dict_to_gpu(coarse_aff_mi)
                    coarse_aff_gt = dict_to_gpu(coarse_aff_gt)
  • 作用:同上,把新生成的输入/标签搬到 GPU。
                    coarse_aff_out = train_iter_coarse_aff(
                        coarse_aff_mi,
                        coarse_aff_gt,
                        aff_refine_model,
                        aff_refine_opt,
                        aff_refine_loss_fn,
                        args,
                        aff_refine_voxel_grid_pts, aff_refine_reso_grid,
                        rot_mat_grid, args.data.rot_grid_bins,
                        it, current_epoch,
                        logger, refine_pred=True,
                        mc_vis=mc_vis)
  • 作用:再跑一次 train_iter_coarse_aff,并在输出中标记 refine_pred=True 以区别首轮与二轮迭代。
  • 为什么这么写:同上。

2.2.2 从粗预测生成 refine_pose 输入
                if args.experiment.train.refine_pose_from_coarse_pred:
                	# process coarse prediction and refine input to create new refine input
                    refine_pose_mi, refine_pose_gt = coarse_aff_to_refine_pose(
                        coarse_aff_mi, coarse_aff_out['model_output'], coarse_aff_gt,
                        refine_pose_mi, refine_pose_gt, 
                        rot_mat_grid, voxel_grid_pts, args, mc_vis=mc_vis)
  • 作用:如果配置要求,用粗预测直接生成给微调网络的新输入/标签,调用 coarse_aff_to_refine_pose

  • 接口

    def coarse_aff_to_refine_pose(
        coarse_mi, coarse_pred, coarse_gt,
        prev_refine_mi, prev_refine_gt,
        rot_mats, voxel_pts, args, mc_vis=None
    ) -> Tuple[new_refine_mi,new_refine_gt]:
        # 1. 从 coarse_pred 提取最可能的位姿候选
        # 2. 更新 refine_pose 的输入/标签
    
  • 为什么这么写:连贯地将两个任务串联,训练过程中让 refine 网络看到最新的粗预测。


2.2.3 从粗预测生成 success 输入
                if args.experiment.train.success_from_coarse_pred:
                	# process coarse prediction and success input to create new success input
                    success_mi, success_gt = coarse_aff_to_success(
                        coarse_aff_mi, coarse_aff_out['model_output'], coarse_aff_gt, 
                        success_mi,  success_gt, 
                        rot_mat_grid, voxel_grid_pts, args, mc_vis=mc_vis)
  • 作用:同样地,用粗预测生成成功判断子任务的输入/标签,调用 coarse_aff_to_success
  • 为什么这么写:通过粗预测结果辅助训练成功分类器。

2.3 训练微调(Refine Pose)子任务
            if args.experiment.train.train_refine_pose and (len(refine_pose_mi) > 0):
  • 作用:若打开该子任务且输入非空,则进入。
				# prepare input and gt
                refine_pose_mi = dict_to_gpu(refine_pose_mi)
                refine_pose_gt = dict_to_gpu(refine_pose_gt)
  • 作用:上 GPU。
                refine_pose_out = train_iter_refine_pose(
                    refine_pose_mi,
                    refine_pose_gt,
                    refine_pose_model,
                    pr_optimizer,
                    pr_loss_fn,
                    args,
                    it, current_epoch,
                    logger,
                    mc_vis=mc_vis)
  • 接口

    def train_iter_refine_pose(
        mi:Dict, gt:Dict,
        model:nn.Module,
        optimizer:Optimizer,
        loss_fn:Callable,
        args:AttrDict,
        iter_num:int, epoch:float,
        logger:SummaryWriter,
        mc_vis:Visualizer=None
    ) -> Dict[str,Any]:
        # 1. optimizer.zero_grad()
        # 2. pred = model(mi)
        # 3. loss = loss_fn(pred, gt)
        # 4. loss.backward()
        # 5. optimizer.step()
        # 6. logger.add_scalar...
        # 7. 可视化(mc_vis)
        # 8. return {'model_output':pred, 'loss':{'pose_loss':...}}
    
  • 为什么这么写:同样封装迭代细节。

				# process output for logging
                for k, v in refine_pose_out['loss'].items():
                    loss_dict[k] = v
  • 作用:收集该子任务的所有损失项到 loss_dict
                if args.experiment.train.success_from_refine_pred:
                	# process coarse prediction and success input to create new success input
                    success_mi, success_gt = refine_pose_to_success(
                        refine_pose_mi, refine_pose_out['model_output'], refine_pose_gt, 
                        success_mi, success_gt, 
                        args, mc_vis=mc_vis) 
  • 作用:若开启“从 refine 输出预测 success”,调用 refine_pose_to_success,生成 success 输入/标签。

2.4 训练成功分类(Success)子任务
            if args.experiment.train.train_success and (len(success_mi) > 0):
  • 作用:若配置开启且输入非空,则进入。
                if args.experiment.train.success_from_refine_pred and (len(refine_pose_mi) == 0):
                    print(f'Skipping success due to no refine pose model input')
                    continue
  • 作用:若配置要求从 refine 预测生成 success,但 refine 模块未提供输入(长度为 0),则跳过本 batch 的 success 部分,继续下一个 batch。
  • 为什么这么写:防止空输入导致报错。
				# prepare input and gt
                success_mi = dict_to_gpu(success_mi)
                success_gt = dict_to_gpu(success_gt)
  • 作用:上 GPU。
                success_out = train_iter_success(
                    success_mi,
                    success_gt,
                    success_model,
                    sc_optimizer,
                    sc_loss_fn,
                    args,
                    it, current_epoch,
                    logger,
                    mc_vis=mc_vis)
  • 接口

    def train_iter_success(
        mi:Dict, gt:Dict,
        model:nn.Module,
        optimizer:Optimizer,
        loss_fn:Callable,
        args:AttrDict,
        iter_num:int, epoch:float,
        logger:SummaryWriter,
        mc_vis:Visualizer=None
    ) -> Dict[str,Any]:
        # 同上结构
    
  • 为什么这么写:与前两者保持一致、易于维护。

				# process output for logging
                for k, v in success_out['loss'].items():
                    loss_dict[k] = v
  • 作用:收集 success 子任务的损失项。

2.5 日志打印 与 Checkpoint 保存
            if it % args.experiment.log_interval == 0 and args.experiment.train.out_log_full:
  • 作用:每隔 log_interval 步(如 100 步)且开启完整日志输出时,执行打印和 TensorBoard 写入。
  • 示例:若 log_interval=50,当 it=50,100,150… 时触发。
                string = f'Iteration {it} -- '
  • 作用:初始化日志字符串。
                for loss_name, loss_val in loss_dict.items():
                    if isinstance(loss_val, dict):
                        # don't let these loss dicts get more than two levels deep
                        for k, v in loss_val.items():
                            string += f'{k}: {v.mean().item():.6f} '
                            logger.add_scalar(k, v.mean().item(), it)
                    else:
                        string += f'{loss_name}: {loss_val.mean().item():.6f} '
                        logger.add_scalar(loss_name, loss_val.mean().item(), it)
  • 作用

    1. 遍历 loss_dict,支持两层嵌套(如某些子损失又返回字典)
    2. 对每个张量 v.mean().item(),得到 Python float,如 0.123456
    3. 拼入日志字符串;
    4. logger.add_scalar(name, value, it) 写入 TensorBoard。
  • 示例

    Iteration 100 -- aff_loss: 0.123456 pose_loss: 0.045678 success_loss: 0.210000 
    
                if args.experiment.debug:
                    from IPython import embed; embed()
  • 作用:若开启 debug 模式,进入交互式终端,方便现场调试中断点。
  • 为什么这么写:实用的调试手段。
                end_time = time.time()
                total_duration = end_time - start_time
  • 作用:计算本次 batch 的总耗时,如 0.2345 秒。
                string += f'duration: {total_duration:.4f}'
                print(string)
  • 作用:将耗时拼接进入日志并打印到屏幕,如

    Iteration 100 -- aff_loss:0.123456 pose_loss:0.045678 success_loss:0.210000 duration:0.2345
    

            if it % args.experiment.save_interval == 0 and it > 0:
  • 作用:每隔 save_interval 步(如 1000)且 it>0 时,保存模型 checkpoint。
  • 示例:当 save_interval=500it=500,1000,1500… 时触发。
                model_path = osp.join(logdir, f'model_{it}.pth')
                model_path_latest = osp.join(logdir, 'model_latest.pth')
  • 作用:拼出当前迭代号和“最新”两个文件名,前者用于版本管理,后者始终指向最新模型。
                ckpt = {'args': config_util.recursive_dict(args)}
  • 作用:新建 checkpoint 字典,先保存配置快照。
  • 为什么这么写:确保重现训练环境。
                ckpt['coarse_aff_model_state_dict']    = coarse_aff_model.state_dict()
                ckpt['refine_pose_model_state_dict']   = refine_pose_model.state_dict()
                ckpt['success_model_state_dict']       = success_model.state_dict()
  • 作用:分别保存三大模型的权重字典(dict[str,Tensor])。
                ckpt['aff_optimizer_state_dict']       = aff_optimizer.state_dict()
                ckpt['pr_optimizer_state_dict']        = pr_optimizer.state_dict()
                ckpt['sc_optimizer_state_dict']        = sc_optimizer.state_dict()
  • 作用:保存对应优化器的内部状态(包括动量、学习率调度器步数等)。
                if args.model.coarse_aff.multi_model:
                    ckpt['coarse_aff_model_state_dict2'] = coarse_aff_model2.state_dict()
                    ckpt['aff_optimizer_state_dict2']   = aff_optimizer2.state_dict()
  • 作用:若启用双模型,再保存第二模型及其优化器状态。
                torch.save(ckpt, model_path)
                torch.save(ckpt, model_path_latest)
  • 作用:将 ckpt 序列化存盘,两个路径各写一份。
  • 为什么这么写:保留历史版本同时更新“最新”链接,便于恢复最新或指定版本。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值