iDP3的Learning代码解析:逐步分解iDP3的数据集、模型、动作预测策略代码(包含2D和3D两个版本)

前言

今25年1.14日起,我和同事孙老师连续出差苏州、无锡、南京、上海

  • 1.14日在苏州,一家探讨人形合作研发,一家是客户
  • 1.15-1.16两天在南京,和同事姚博士、合作商一块接待一机器人集团客户
    客户表示高校偏科研,但我们做到了科研与落地并重,很希望合作——主动提出拉群保持逐月推进
  • 1.17日在无锡,参观一集团工厂、交流可合作开发的业务场景,并约定年后再去一趟电器厂
  • 1.18日则在上海约了4位,分别来自两人形公司、一国家级实验室、一大模型独角兽

我们再次感慨,绝大部分工厂都将在今2025年开始做一系列智能升级、智能改造

  1. 而背后用的策略方法,也将从传统的深度学习方法,往大模型 + 模仿学习 + RL方面迁移,这是势不可挡的大趋势
    至于目前我们具体到底用的啥策略/架构,取决于具体的场景或任务,比如(只是比如)此文《斯坦福通用人形策略iDP3——同一套策略控制各种机器人:改进3D扩散策略,不再依赖相机校准和点云分割(含DP3的详解)》所述的iDP3

    毕竟,“ 截止到25年1.12日,我们idp3的复现迎来大进展,idp3架构拆解完了,且还弄了一个通用架构:可以同时跑跑umi、dexcap、dp3、ipd3这4个模型
  2. 因此在出差的间隙,我于1.15日~1.16日把人形动作预测策略——ipd3源码的所有代码文件整体看了下,确实如姚博士所说,模块清晰 各司其职

    本想着​这几天出差完后 把ipd3的源码也做下解读,想了下,只要有时间空闲,我便开始解读吧
    包括1.17日从无锡来上海的路上——高铁上 酒店大堂里 网约车上 餐厅里,我都拿出了MacBook Pro修订本篇《iDP3源码剖析》博客
    可能这就是为何每次出差,和做AI 大模型 具身的技术人交流时,十之八九都看过我博客(不管在哪个TOP高校 不管在哪个大厂)的原因吧,背后毕竟有着十多年的积累​​​」

于此,今天便有了本文「注意,看本文之前,建议先通过此文了解 iDP3的原理」,且重点分析其learning的代码:Improved-3D-Diffusion-Policy,至于部署代码的分析见此文iDP3的训练与部署代码解析:从数据可视化vis_dataset.py、训练脚本train.py到部署脚本deploy.py

而为了让本文的源码剖析足够清晰,我是花了不少心思的,因为源码分析其实很容易变成各种堆砌代码——相信 这种堆砌 大家也看多了,所以我特意做了以下这几点措施

  1. 每一段待解读的代码,尽可能控制在10行以内,因为按我的经验,超过10行 看着就累了
  2. 即便有解读,贴的代码 也要逐行都有对应的注释
    因为这样 可以更加一目了然
  3. 为了随时让读者知道某个被分析的函数处在哪个文件夹下,以及在整体中的位置及与前后代码文件的关联
    对于较长的代码文件,我会特意在分析代码文件之前,贴一下对应的代码结构截图
    如此,还是为了一目了然
  4. 每个章节的代码文件名称都加上了对应的一句话说明,这样让大家对“被分析的代码文件是具体干什么的”可以一目了然,且更让整个目录更有全局感,更清晰 

目录

前言

第一部分 数据集与配置:diffusion_policy_3d的common、config、dataset、workspace

1.1 diffusion_policy_3d/common

1.1.1 common/gr1_action_util.py:转换和处理与关节和末端执行器EEF相关的数据

1.1.3 common/multi_realsense.py:管理和处理多个 RealSense 摄像头的数据流

1.2 diffusion_policy_3d/config:决定是用2D策略还是3D策略

1.2.1 config/task:gr1_dex-3d.yaml(相当于具体的任务)、gr1_dex-image.yaml

1.2.2 config/dp_224x224_r3m.yaml

1.2.3 config/idp3.yaml:相当于配置文件

1.3 diffusion_policy_3d/workspace

1.3.1 workspace/base_workspace.py

1.3.2 workspace/dp_workspace.py

1.3.3 workspace/idp3_workspace.py:相当于实现文件(很重要)

1.4 diffusion_policy_3d/dataset:各种数据集及相关处理

1.4.1 dataset/base_dataset.py:低维、图像、点云、通用等4类数据集

1.4.2 dataset/gr1_dex_dataset_3d.py:处理 3D 数据集

1.4.3 dataset/gr1_dex_dataset_image.py:处理图像和深度信息

第二部分 扩散模型与3D点云编码器的实现:diffusion_policy_3d/model

2.1 model/common

2.2 model/diffusion:涉及五个代码文件

2.2.1 diffusion/conditional_unet1d.py:再涉及3个代码子文件,分别实现交叉注意力、条件残差块、条件U-Net 网络

2.2.1.1 CrossAttention 类:实现交叉注意力

2.2.1.2 ConditionalResidualBlock1D 类:条件残差块,在一维卷积网络中实现条件处理

2.2.1.3 ConditionalUnet1D:条件一维 U-Net 网络,在一维数据上实现条件生成任务

2.2.2 diffusion/conv1d_components.py:涉及一维卷积、下采样、上采样

2.2.3 diffusion/ema_model.py:实现模型权重的指数移动平均EMA

2.2.4 diffusion/mask_generator.py

2.2.5 diffusion/positional_embedding.py:为输入数据添加位置信息

2.3 model/vision

2.4 model/vision_3d

2.4.1 vision_3d/multi_stage_pointnet.py:对点云数据进行编码

2.4.2 vision_3d/point_process.py:针对点云的打乱/填充/采样操作(含NumPy和PyTorch实现)

2.4.3 vision_3d/pointnet_extractor.py:包含点云编码器iDP3Encoder的实现

第三部分 基于图像和点云的扩散策略:diffusion_policy_3d/policy(相当于包含2D和3D两个版本)

3.1 policy/base_policy.py:基类策略模型

3.2 2D版本——policy/diffusion_image_policy.py:基于图像的扩散策略

3.2.1 __init__

3.2.2 forward:根据输入的观察字典 obs_dict 生成动作

3.2.3 conditional_sample:给定条件下的采样

3.2.4 predict_action:根据输入的观察字典obs_dict预测动作

3.2.5 compute_loss:计算给定批次数据的损失

3.3 3D版本——policy/diffusion_pointcloud_policy.py:基于点云的扩散策略(与3.2节有相似)

3.3.1 __init__

3.3.2 forward:根据输入的观察字典 obs_dict 生成动作

3.3.3 conditional_sample:在给定条件下进行采样

3.3.4 predict_action:根据输入的观察字典 obs_dict 生成动作(与forward类似)

3.3.5 compute_loss:计算给定批次数据的损失


第一部分 数据集与配置:diffusion_policy_3d的common、config、dataset、workspace

1.1 diffusion_policy_3d/common

本common文件夹下 有一些代码文件 暂未解读,比如checkpoint_util.py——该类用于管理模型训练过程中的检查点(checkpoint),确保只保留性能最好的k个检查点

1.1.1 common/gr1_action_util.py:转换和处理与关节和末端执行器EEF相关的数据

该代码片段主要用于转换和处理与机器人关节和末端执行器EEF相关的数据

  1. 首先,导入了numpy、torch以及自定义的rotation_util模块,并定义了若干初始姿态与位置变量(init_arm_pos、init_arm_quat等)
  2. joint32_to_joint25函数将包含32个关节数据的数组转换为只包含25个关节数据的数组,主要通过选择和映射腰部、头部、手臂与手的关节索引
  3. joint25_to_joint32函数则执行反向操作,将25个关节数据填充回32个
  4. extract_eef_action函数从传入的eef_action向量中提取身体动作、双臂位置和旋转,以及手部动作
    这里的手臂旋转采用6D表示法,用rotation_util模块可进一步转换至四元数
  5. 最后,extract_abs_eef 函数基于增量位置和旋转,计算得到新的绝对位置和旋转。它会先将四元数转换至6D旋转进行相加,再通过rotation_util还原为新的四元数,以便完整表达最终的末端执行器位姿

1.1.3 common/multi_realsense.py:管理和处理多个 RealSense 摄像头的数据流

该代码片段主要用于管理和处理多个 RealSense 摄像头的数据流。它包括初始化摄像头、获取摄像头数据、处理点云数据等功能

首先,导入各个库和定义各个函数

具体而言

  1. 首先,导入了必要的库和模块,包括 multiprocessing、`numpy` 和 `pyrealsense2`
    设置了多进程的启动方法为 `fork`,并配置了 numpy 的打印选项
  2. get_realsense_id 函数用于获取连接到系统的所有 RealSense 摄像头的序列号,并返回这些序列号的列表
  3. init_given_realsense 函数用于初始化指定的 RealSense 摄像头
    它接受多个参数,包括设备序列号、是否启用 RGB 和深度流、是否启用点云、同步模式
    根据这些参数配置摄像头,并返回摄像头的管道、对齐对象、深度比例和相机信息
  4. grid_sample_pcd 函数用于对点云数据进行网格采样。它接受一个点云数组和网格大小,返回采样后的点云数组

其次,CameraInfo 类用于存储相机的内参信息,包括宽度、高度、焦距、主点坐标和比例

接下来,SingleVisionProcess 类继承自 Process,用于管理单个摄像头的数据流。它在初始化时接受多个参数,包括设备序列号、队列、是否启用 RGB 和深度流、是否启用点云、同步模式、点云数量、远近裁剪距离、是否使用网格采样和图像大小

def __init__(self, device, queue,      # 初始化方法,接受设备和队列作为参数
                enable_rgb=True,       # 是否启用 RGB 流,默认值为 True
                enable_depth=False,          # 是否启用深度流,默认值为 False
                enable_pointcloud=False,     # 是否启用点云,默认值为 False
                sync_mode=0,              # 同步模式,默认值为 0
                num_points=2048,          # 点云数量,默认值为 2048
                z_far=1.0,           # 远裁剪距离,默认值为 1.0
                z_near=0.1,          # 近裁剪距离,默认值为 0.1
                use_grid_sampling=True,      # 是否使用网格采样,默认值为 True
                img_size=224) -> None:       # 图像大小,默认值为 224

类中定义了 

  1. get_vision 方法用于获取摄像头数据
  2. run 方法用于启动摄像头数据流
  3. terminate 方法用于终止数据流
  4. create_colored_point_cloud 方法用于创建带颜色的点云

最后,MultiRealSense 类用于管理多个 RealSense 摄像头。它在初始化时接受多个参数,包括是否使用前置和右侧摄像头、摄像头索引、点云数量、远近裁剪距离、是否使用网格采样和图像大小,详见如下

# 初始化方法,接受多个参数,默认使用前置摄像头,不使用右侧摄像头
def __init__(self, use_front_cam=True, use_right_cam=False,  
                 # 前置摄像头和右侧摄像头的索引,默认值分别为 0 和 1
                 front_cam_idx=0, right_cam_idx=1,  

                 # 前置摄像头和右侧摄像头的点云数量,默认值分别为 4096 和 1024
                 front_num_points=4096, right_num_points=1024,  

                 # 前置摄像头的远近裁剪距离,默认值分别为 1.0 和 0.1
                 front_z_far=1.0, front_z_near=0.1,  

                 # 右侧摄像头的远近裁剪距离,默认值分别为 0.5 和 0.01
                 right_z_far=0.5, right_z_near=0.01,  

                 use_grid_sampling=True,  # 是否使用网格采样,默认值为 True
                 img_size=384):  # 图像大小,默认值为 384

类中定义了

  1. _call方法,用于获取摄像头数据
  2. finalize方法,用于终止所有摄像头的数据流
  3. _del_方法用于在对象销毁时调用finalize方法

通过这些类和函数,代码实现了对多个 RealSense 摄像头的数据管理和处理,适用于需要同时处理多个摄像头数据的应用场景

1.2 diffusion_policy_3d/config:决定是用2D策略还是3D策略

这几个配置文件,在训练和部署的时候会用到,详见此文《iDP3的训练与部署代码解析(含预处理):包含训练脚本train.py、以及适配各种机械臂的通讯脚本deploy.py》中介绍的train.py、deploy.py

1.2.1 config/task:gr1_dex-3d.yaml(相当于具体的任务)、gr1_dex-image.yaml

  1. gr1_dex-3d.yaml
    这个 YAML 文件定义了一个名为 `box` 的配置。它包含了形状元数据和数据集的详细信息
    name: box
    在 `shape_meta` 部分,定义了观测和动作的形状和类型——相当于在任务文件中定义了机器人的输入、输出
    观测 (`obs`) 包括两个部分:`point_cloud` 和 `agent_pos`
    `point_cloud` 的形状是 `[4096, 6]`,类型是 `point_cloud`,表示一个点云数据
      obs:
        point_cloud:
          shape: [4096, 6]
          type: point_cloud
    `agent_pos` 的形状是 `[32]`,类型是 `low_dim`,表示低维度数据
        agent_pos:
          shape: [32]
          type: low_dim
    动作 (`action`) 的形状是 `[25]`
      action:
        shape: [25]
    在 `dataset` 部分,定义了数据集的相关配置
    `_target_` 指定了数据集的目标类 `diffusion_policy_3d.dataset.gr1_dex_dataset_3d.GR1DexDataset3D`
      _target_: diffusion_policy_3d.dataset.gr1_dex_dataset_3d.GR1DexDataset3D
    `zarr_path` 指定了数据集的路径 box_zarr

    `horizon`、`pad_before` 和 `pad_after` 是一些动态计算的参数,分别表示时间范围、前填充和后填充
    `seed` 设置为 42,用于随机数生成。`val_ratio` 设置为 0.00,表示没有验证集。`max_train_episodes` 设置为 90,表示最大训练集数。`num_points` 是一个动态参数,取自 `policy.pointcloud_encoder_cfg.num_points`

    总体来说,这个配置文件定义了一个用于 3D 点云数据处理和训练的配置,包含了数据的形状、类型以及数据集的路径和相关参数
  2. gr1_dex-image.yaml
    // 待更

1.2.2 config/dp_224x224_r3m.yaml

// 待更

1.2.3 config/idp3.yaml:相当于配置文件

这个 YAML 文件定义了一个名为 `train_diffusion_unet_hybrid` 的训练配置,目标是使用 3D 点云数据进行扩散模型的训练——可以自动调用对应的workspace文件、policy文件

  1. 在 `defaults` 部分,指定了任务 `dexdeform_flip_pointcloud`,更多如下
    defaults:
      - task: dexdeform_flip_pointcloud       # 默认任务配置
    
    name: train_diffusion_unet_hybrid         # 实验名称
    _target_: diffusion_policy_3d.workspace.idp3_workspace.iDP3Workspace # 目标工作空间类
    
    task_name: ${task.name}             # 任务名称
    shape_meta: ${task.shape_meta}      # 形状元数据
    exp_name: "debug"                   # 实验名称
  2. `n_obs_steps` 和 `horizon` 分别设置为 2 和 16,表示观测步数和时间范围,更多如下
    # n_obs_steps: 2
    # n_obs_steps: 1
    n_obs_steps: 2          # 观测步数
    
    # horizon: 4
    # n_action_steps: 4
    
    # horizon: 16
    # n_action_steps: 15
    
    horizon: 16             # 时间范围
    n_action_steps: 15      # 动作步数
    
    n_latency_steps: 0      # 延迟步数
    dataset_obs_steps: ${n_obs_steps}      # 数据集观测步数
    keypoint_visible_rate: 1.0             # 关键点可见率
    obs_as_global_cond: True               # 观测作为全局条件
    
    use_image: false         # 不使用图像
  3. `policy` 部分定义了策略的详细配置,目标类为 `diffusion_policy_3d.policy.diffusion_pointcloud_policy.DiffusionPointcloudPolicy`——详见下文的《3.3 3D版本——policy/diffusion_pointcloud_policy.py:基于点云的扩散策略(与3.2节有相似)》
    policy:
      # 策略目标类
      _target_: diffusion_policy_3d.policy.diffusion_pointcloud_policy.DiffusionPointcloudPolicy 
    包括使用点裁剪、下中上条件、扩散步骤嵌入维度、降维、核大小、组数等配置
      use_point_crop: true         # 使用点裁剪
      use_down_condition: true     # 使用下条件
      use_mid_condition: true      # 使用中条件
      use_up_condition: true       # 使用上条件
      use_image: false             # 不使用图像
      
      diffusion_step_embed_dim: 128         # 扩散步骤嵌入维度
      down_dims: [256, 512, 1024]           # 降维
    
      horizon: ${horizon}                   # 时间范围
      kernel_size: 5                        # 核大小
      n_action_steps: ${n_action_steps}     # 动作步数
      n_groups: 8                           # 组数
      n_obs_steps: ${n_obs_steps}           # 观测步数
    `noise_scheduler` 部分定义了噪声调度器的配置,目标类为 `diffusers.schedulers.scheduling_ddim.DDIMScheduler`
      noise_scheduler:
        _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler # 噪声调度器目标类
        num_train_timesteps: 50                  # 训练时间步数
        beta_start: 0.0001                       # beta 起始值
        beta_end: 0.02                           # beta 结束值
        beta_schedule: squaredcos_cap_v2         # beta 调度
        clip_sample: True                        # 剪辑样本
        set_alpha_to_one: True                   # 设置 alpha 为 1
        steps_offset: 0                          # 步数偏移
        prediction_type: sample                  # 预测类型
    
      num_inference_steps: 10                    # 推理步数
    
      obs_as_global_cond: true                   # 观测作为全局条件
      shape_meta: ${shape_meta}                  # 形状元数据
    
      use_pc_color: false                        # 不使用点云颜色
      pointnet_type: "multi_stage_pointnet"      # PointNet 类型
    
      point_downsample: true                     # 点下采样
    
      pointcloud_encoder_cfg:
        in_channels: 3              # 输入通道数
        out_channels: 128           # 输出通道数
        use_layernorm: true         # 使用层归一化
        final_norm: layernorm       # 最终归一化类型
        normal_channel: false       # 正常通道
        num_points: 4096            # 点数
  4. `ema` 部分定义了指数移动平均模型的配置,目标类为 `diffusion_policy_3d.model.diffusion.ema_model.EMAModel`
    ema:
      # EMA 模型目标类
      _target_: diffusion_policy_3d.model.diffusion.ema_model.EMAModel 
      update_after_step: 0       # 步后更新
      inv_gamma: 1.0             # 逆伽马
      power: 0.75                # 功率
      min_value: 0.0             # 最小值
      max_value: 0.9999          # 最大值
    `dataloader` 和 `val_dataloader` 部分定义了数据加载器的配置,包括批量大小、工作线程数、是否打乱数据等
    dataloader:
      # batch_size: 120
      batch_size: 64         # 批量大小
      num_workers: 8         # 工作线程数
      shuffle: True          # 是否打乱数据
      pin_memory: True       # 固定内存
      persistent_workers: False     # 持久化工作线程
    
    val_dataloader:
      # batch_size: 120
      batch_size: 64         # 批量大小
      num_workers: 8         # 工作线程数
      shuffle: False         # 是否打乱数据
      pin_memory: True       # 固定内存
      persistent_workers: False     # 持久化工作线程
  5. `optimizer` 部分定义了优化器的配置,目标类为 `torch.optim.AdamW`,包括学习率、动量、权重衰减等
    optimizer:
      # 优化器目标类
      _target_: torch.optim.AdamW 
      lr: 1.0e-4                   # 学习率
      betas: [0.95, 0.999]         # beta 参数
      eps: 1.0e-8                  # epsilon 参数
      weight_decay: 1.0e-6         # 权重衰减
    `training` 部分定义了训练的详细配置,包括设备、随机种子、调试模式、学习率调度器、训练周期数、梯度累积、EMA 使用、检查点保存等
    training:
      device: "cuda:0"         # 设备
      seed: 42                 # 随机种子
      debug: False             # 调试模式
      resume: True             # 恢复训练
      lr_scheduler: cosine     # 学习率调度器
      lr_warmup_steps: 500     # 学习率预热步数
      num_epochs: 301                   # 训练周期数    
      gradient_accumulate_every: 1      # 梯度累积步数
      use_ema: True                     # 使用 EMA
      rollout_every: 400                # 每隔多少步 rollout
      checkpoint_every: 100             # 每隔多少步保存检查点
      val_every: 100                    # 每隔多少步验证
      sample_every: 5                   # 每隔多少步采样
      max_train_steps: null             # 最大训练步数
      max_val_steps: null               # 最大验证步数
      tqdm_interval_sec: 1.0            # tqdm 更新间隔
      save_video: True                  # 保存视频
  6. `logging` 部分定义了日志记录的配置,包括组名、模式、项目名、标签等
    logging:
      group: ${exp_name}     # 日志组名
      id: null                 # 日志 ID
      mode: online             # 日志模式
      name: ${training.seed}      # 日志名称
      project: humanoid_mimic     # 项目名称
      resume: true             # 恢复日志
      tags:
      - train_diffusion_unet_hybrid     # 标签
      - dexdeform                       # 标签
    `checkpoint` 部分定义了检查点的保存配置
    checkpoint:
      save_ckpt: False         # 是否保存检查点
      topk:
        monitor_key: test_mean_score     # 监控键
        mode: max              # 模式
        k: 0 # top k
    
        # 格式字符串
        format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 
      save_last_ckpt: True               # 保存最后一个检查点
      save_last_snapshot: False          # 保存最后一个快照
    `multi_run` 部分定义了多次运行的目录配置
    multi_run:
      # 运行目录
      run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 
    
      # wandb 名称基础
      wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
    `hydra` 部分定义了 Hydra 框架的运行和扫描目录配置
    hydra:
      job:
        override_dirname: ${name}         # 覆盖目录名
      run:
        # 运行目录
      sweep:
        dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 
    
        # 扫描目录
        dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 
    
        # 子目录
        subdir: ${hydra.job.num} 

总体来说,这个配置文件详细定义了一个用于 3D 点云数据扩散模型训练的完整配置,包括任务、策略、数据加载、优化器、训练、日志记录和检查点保存等各个方面。

1.3 diffusion_policy_3d/workspace

1.3.1 workspace/base_workspace.py

1.3.2 workspace/dp_workspace.py

1.3.3 workspace/idp3_workspace.py:相当于实现文件(很重要)

iDP3Workspace 类继承自 BaseWorkspace,用于配置和运行 3D 点云数据的扩散模型训练,其与上面“idp3.yaml”的不同在于

  • idp3.yaml 是一个配置文件,用于定义训练任务的各种参数和设置
  • 而 idp3_workspace.py 是一个实现文件,包含了具体的训练和验证逻辑

某种程度上来讲,如果你想改程序,比如想用其他的模型、想改优化算法、或者想改这个模型的架构,那应该做什么呢?只需要改上面的yaml配置文件、workspace工作文件以及模型策略文件policy即可

改完之后,再train_policy里对应train

具体而言

  1. 首先,分别从以下文件夹里引入了一系列库

    workspace - base_workspace
    policy - diffusion_pointcloud_policy
    common - checkpoint_util
    common - json_logger
    common - pytorch_util
    model - diffusion - ema_model
    model - common - lr_scheduler
    # 从 base_workspace 模块导入 BaseWorkspace 类
    from diffusion_policy_3d.workspace.base_workspace import BaseWorkspace 
    
    # 从 diffusion_pointcloud_policy 模块导入 DiffusionPointcloudPolicy 类
    from diffusion_policy_3d.policy.diffusion_pointcloud_policy import DiffusionPointcloudPolicy 
    
    # 从 checkpoint_util 模块导入 TopKCheckpointManager 类
    from diffusion_policy_3d.common.checkpoint_util import TopKCheckpointManager 
    
    # 从 json_logger 模块导入 JsonLogger 类
    from diffusion_policy_3d.common.json_logger import JsonLogger 
    
    # 从 pytorch_util 模块导入 dict_apply 和 optimizer_to 函数
    from diffusion_policy_3d.common.pytorch_util import dict_apply, optimizer_to 
    
    # 从 ema_model 模块导入 EMAModel 类
    from diffusion_policy_3d.model.diffusion.ema_model import EMAModel 
    
    # 从 lr_scheduler 模块导入 get_scheduler 函数
    from diffusion_policy_3d.model.common.lr_scheduler import get_scheduler 

    其中,值得重点注意的是,导入了diffusion_policy_3d下policy文件夹下代码文件diffusion_pointcloud_policy.py(其来自Improved-3D-Diffusion-Policy/diffusion_policy_3d/policy/diffusion_pointcloud_policy.py)中的DiffusionPointcloudPolicy

  2. 在 __init__ 方法中,首先通过 cfg.training.seed 设置随机种子,以确保结果的可重复性
    # 定义 iDP3Workspace 类,继承自 BaseWorkspace
    class iDP3Workspace(BaseWorkspace): 
        include_keys = ['global_step', 'epoch'] # 包含的键
    
        def __init__(self, cfg: OmegaConf, output_dir=None):     # 初始化方法
            super().__init__(cfg, output_dir=output_dir)         # 调用父类的初始化方法
            
            # 设置随机种子
            seed = cfg.training.seed         # 从配置中获取随机种子
            torch.manual_seed(seed)          # 设置 PyTorch 的随机种子
            np.random.seed(seed)             # 设置 NumPy 的随机种子
            random.seed(seed)                # 设置 Python 的随机种子
    接着,使用 hydra.utils.instantiate 方法实例化模型 DiffusionPointcloudPolicy
            # 配置模型
            self.model: DiffusionPointcloudPolicy = hydra.utils.instantiate(cfg.policy) # 实例化模型
    如果配置中启用了 EMA(指数移动平均),则尝试复制模型,否则重新实例化模型
            # 初始化 EMA 模型为 None
            self.ema_model: DiffusionPointcloudPolicy = None 
            if cfg.training.use_ema: # 如果使用 EMA
                try:
                    self.ema_model = copy.deepcopy(self.model)           # 尝试复制模型
                except: # 如果复制失败
                    self.ema_model = hydra.utils.instantiate(cfg.policy) # 重新实例化模型
    然后,实例化优化器,并初始化训练状态,包括 global_step 和 epoch
            # 配置训练状态
            self.optimizer = hydra.utils.instantiate(
                cfg.optimizer, params=self.model.parameters())     # 实例化优化器
    
            # 配置训练状态
            self.global_step = 0         # 初始化全局步数为 0
            self.epoch = 0               # 初始化周期数为 0
  3. run 方法是训练的核心逻辑
    首先,深拷贝配置文件 cfg。如果启用了调试模式,则调整训练参数以进行快速测试
        def run(self): 
            cfg = copy.deepcopy(self.cfg)         # 深拷贝配置
            
            # 如果是调试模式
            if cfg.training.debug:             
                cfg.training.num_epochs = 40           # 设置训练周期数为 40
                cfg.training.max_train_steps = 10      # 设置最大训练步数为 10
                cfg.training.max_val_steps = 3         # 设置最大验证步数为 3
                cfg.training.rollout_every = 20        # 设置每隔 20 步进行 rollout
                cfg.training.checkpoint_every = 1      # 设置每隔 1 步保存检查点
                cfg.training.val_every = 1             # 设置每隔 1 步进行验证
                cfg.training.sample_every = 1          # 设置每隔 1 步进行采样
                RUN_ROLLOUT = True                     # 设置运行 rollout 为 True
                RUN_CKPT = False                       # 设置运行检查点为 False
                verbose = True                         # 设置详细模式为 True
    
            # 如果不是调试模式
            else: 
                RUN_ROLLOUT = True                     # 设置运行 rollout 为 True
                RUN_CKPT = True                        # 设置运行检查点为 True
                verbose = False                        # 设置详细模式为 False
    接着,检查是否需要从检查点恢复训练,并加载最新的检查点
            # 恢复训练
            if cfg.training.resume:                 # 如果需要恢复训练
                lastest_ckpt_path = self.get_checkpoint_path()             # 获取最新的检查点路径
                if lastest_ckpt_path.is_file():     # 如果检查点文件存在
                    print(f"Resuming from checkpoint {lastest_ckpt_path}") # 打印恢复信息
                    self.load_checkpoint(path=lastest_ckpt_path)           # 加载检查点
    然后,实例化数据集和数据加载器,并获取数据的归一化器
            # 配置数据集
            dataset = hydra.utils.instantiate(cfg.task.dataset)      # 实例化数据集
            train_dataloader = DataLoader(dataset, **cfg.dataloader) # 创建训练数据加载器
            normalizer = dataset.get_normalizer()                    # 获取数据归一化器
    
            # 配置验证数据集
            val_dataset = dataset.get_validation_dataset()                 # 获取验证数据集
            val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader) # 创建验证数据加载器
    配置学习率调度器和 EMA 模型(如果启用)
            self.model.set_normalizer(normalizer) # 设置模型的归一化器
            if cfg.training.use_ema: # 如果使用 EMA
                self.ema_model.set_normalizer(normalizer) # 设置 EMA 模型的归一化器
    
            # 配置学习率调度器
            lr_scheduler = get_scheduler(
                cfg.training.lr_scheduler,     # 获取学习率调度器
                optimizer=self.optimizer,      # 优化器
                num_warmup_steps=cfg.training.lr_warmup_steps,     # 预热步数
                num_training_steps=(
                    len(train_dataloader) * cfg.training.num_epochs) \
                        // cfg.training.gradient_accumulate_every, # 训练步数
                last_epoch=self.global_step-1 # 上一个周期
            )
    
            # 配置 EMA
            ema: EMAModel = None                 # 初始化 EMA 为 None
            if cfg.training.use_ema:             # 如果使用 EMA
                ema = hydra.utils.instantiate(
                    cfg.ema, # 实例化 EMA
                    model=self.ema_model)        # 设置 EMA 模型
    
            cfg.logging.name = str(cfg.logging.name)                  # 将日志名称转换为字符串
            cprint("-----------------------------", "yellow")         # 打印分隔线
            cprint(f"[WandB] group: {cfg.logging.group}", "yellow")   # 打印 WandB 组名
            cprint(f"[WandB] name: {cfg.logging.name}", "yellow")     # 打印 WandB 名称
            cprint("-----------------------------", "yellow")         # 打印分隔线
  4. 接下来,配置日志记录和检查点管理
            # 配置日志记录
            wandb_run = wandb.init(
                dir=str(self.output_dir),                             # 设置输出目录
                config=OmegaConf.to_container(cfg, resolve=True),     # 将配置转换为容器
                **cfg.logging                          # 日志配置
            )
            wandb.config.update(
                {
                    "output_dir": self.output_dir,     # 更新输出目录
                }
            )
    
            # 配置检查点
            topk_manager = TopKCheckpointManager(
                save_dir=os.path.join(self.output_dir, 'checkpoints'), # 设置检查点保存目录
                **cfg.checkpoint.topk                  # 检查点配置
            )
    将模型和优化器转移到指定设备(如 GPU)
            # 设备转移
            device = torch.device(cfg.training.device) # 获取训练设备
            self.model.to(device)                      # 将模型转移到设备
            if self.ema_model is not None:             # 如果 EMA 模型存在
                self.ema_model.to(device)                 # 将 EMA 模型转移到设备
            optimizer_to(self.optimizer, device)          # 将优化器转移到设备
    
            # 保存采样批次
            train_sampling_batch = None                   # 初始化训练采样批次为 None
    在训练循环中,遍历每个 epoch 和 batch
            # 训练循环
            # 设置日志路径
            log_path = os.path.join(self.output_dir, 'logs.json.txt') 
    
            # 使用 JSON 日志记录器
            with JsonLogger(log_path) as json_logger: 
                # 遍历训练周期
                for local_epoch_idx in tqdm.tqdm(range(cfg.training.num_epochs), desc=f"Training"): 
                    step_log = dict() # 初始化步骤日志
    
                    # ========= 本周期训练 ==========
                    # 初始化训练损失列表
                    train_losses = list() 
    
                    # 遍历训练数据加载器
                    for batch_idx, batch in enumerate(train_dataloader): 
                        # 记录当前时间
                        t1 = time.time() 
    
                        # 设备转移
                        batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True) if isinstance(x, torch.Tensor) else x) # 将批次数据转移到设备
                        # 如果训练采样批次为空
                        if train_sampling_batch is None: 
                            # 设置训练采样批次
                            train_sampling_batch = batch
    计算损失并进行反向传播
                        # 计算损失
                        t1_1 = time.time()      # 记录当前时间
    
                        # 计算原始损失和损失字典
                        raw_loss, loss_dict = self.model.compute_loss(batch) 
    
                        # 计算平均损失
                        loss = raw_loss / cfg.training.gradient_accumulate_every 
                        loss.backward()         # 反向传播
    
                        t1_2 = time.time()      # 记录当前时间
    
    更新优化器和 EMA 模型,并记录训练日志
                        # 优化器步进
                        if self.global_step % cfg.training.gradient_accumulate_every == 0: # 如果全局步数是梯度累积步数的倍数
                            self.optimizer.step()         # 优化器步进
                            self.optimizer.zero_grad()    # 清零梯度
                            lr_scheduler.step()           # 学习率调度器步进
    
                        # 记录当前时间
                        t1_3 = time.time() 
    
                        # 更新 EMA
                        if cfg.training.use_ema:          # 如果使用 EMA
                            ema.step(self.model)          # EMA 步进
    
                        # 记录当前时间
                        t1_4 = time.time() 
    
                        # 日志记录
                        raw_loss_cpu = raw_loss.item()          # 获取原始损失的值
                        train_losses.append(raw_loss_cpu)       # 将原始损失添加到训练损失列表
                        step_log = {
                            'train_loss': raw_loss_cpu,         # 训练损失
                            'global_step': self.global_step,    # 全局步数
                            'epoch': self.epoch, # 周期数
                            'lr': lr_scheduler.get_last_lr()[0] # 学习率
                        }
    
                        t1_5 = time.time()             # 记录当前时间
                        step_log.update(loss_dict)     # 更新步骤日志
                        t2 = time.time()               # 记录当前时间
                        
                        # 如果是详细模式
                        if verbose: 
                            # 打印总步时间
                            print(f"total one step time: {t2-t1:.3f}")     
    
                            # 打印计算损失时间
                            print(f" compute loss time: {t1_2-t1_1:.3f}") 
    
                            # 打印优化器步进时间
                            print(f" step optimizer time: {t1_3-t1_2:.3f}") 
    
                            # 打印更新 EMA 时间
                            print(f" update ema time: {t1_4-t1_3:.3f}") 
    
                            # 打印日志记录时间
                            print(f" logging time: {t1_5-t1_4:.3f}") 
    
                        # 判断是否是最后一个批次
                        is_last_batch = (batch_idx == (len(train_dataloader)-1)) 
    
                        # 如果不是最后一个批次
                        if not is_last_batch: 
                            # 最后一步的日志记录与验证和 rollout 结合
                            # 记录日志到 WandB
                            wandb_run.log(step_log, step=self.global_step) 
    
                            # 记录日志到 JSON 日志记录器
                            json_logger.log(step_log) 
                            # 全局步数加 1
                            self.global_step += 1 
    
                        # 如果达到最大训练步数
                        if (cfg.training.max_train_steps is not None) \
                            and batch_idx >= (cfg.training.max_train_steps-1): 
                            break # 跳出循环
    注意,在每个 epoch 结束时,进行模型验证——其中会调用Improved-3D-Diffusion-Policy/diffusion_policy_3d/policy/diffusion_pointcloud_policy.py中的predict_action,详见下文的「3.3 3D版本——policy/diffusion_pointcloud_policy.py:基于点云的扩散策略(与3.2节有相似)
                    # 在每个周期结束时
                    # 用周期平均值替换训练损失
                    train_loss = np.mean(train_losses) # 计算训练损失的平均值
                    step_log['train_loss'] = train_loss # 更新步骤日志中的训练损失
    
                    # ========= 本周期评估 ==========
                    policy = self.model # 设置策略为模型
                    if cfg.training.use_ema: # 如果使用 EMA
                        policy = self.ema_model # 设置策略为 EMA 模型
                    policy.eval() # 设置策略为评估模式
                    
                    # 运行验证
                    if (self.epoch % cfg.training.val_every) == 0: # 如果需要进行验证
                        with torch.no_grad(): # 禁用梯度计算                        
                            train_losses = list() # 初始化训练损失列表
                            
                            # 遍历训练数据加载器
                            for batch_idx, batch in enumerate(train_dataloader): 
                                # 将批次数据转移到设备
                                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True) if isinstance(x, torch.Tensor) else x)       
    
                                # 获取观测数据
                                obs_dict = batch['obs']         
                                # 获取真实动作    
                                gt_action = batch['action']         
    
                                # 预测动作
                                result = policy.predict_action(obs_dict) 
    
                                # 获取预测动作
                                pred_action = result['action_pred'] 
                                # 计算均方误差
                                mse = torch.nn.functional.mse_loss(pred_action, gt_action) 
                                # 将均方误差添加到训练损失列表
                                train_losses.append(mse.item()) 
                                
                                # 如果达到最大训练步数
                                if (cfg.training.max_train_steps is not None) \
                                    and batch_idx >= (cfg.training.max_train_steps-1): 
                                    # 跳出循环
                                    break 
    
                            # 计算训练损失的总和
                            train_loss = np.sum(train_losses) 
    
                            # 记录周期平均验证损失
                            # 更新步骤日志中的训练动作均方误差
                            step_log['train_action_mse_error'] = train_loss 
    
                            # 更新步骤日志中的测试平均分数
                            step_log['test_mean_score'] = - step_log['train_action_mse_error'] 
    
                            # 打印验证损失
                            cprint(f"val loss: {train_loss:.7f}", "cyan") 
    并根据配置保存检查点
                    # 检查点
                     # 如果需要保存检查点
                    if (self.epoch % cfg.training.checkpoint_every) == 0 and cfg.checkpoint.save_ckpt:
                        # 保存检查点
                        if cfg.checkpoint.save_last_ckpt: # 如果需要保存最后一个检查点
                            self.save_checkpoint() # 保存检查点
                        if cfg.checkpoint.save_last_snapshot: # 如果需要保存最后一个快照
                            self.save_snapshot() # 保存快照
    
                        # 清理度量名称
                        metric_dict = dict() # 初始化度量字典
                        for key, value in step_log.items(): # 遍历步骤日志
                            new_key = key.replace('/', '_') # 替换度量名称中的斜杠
                            metric_dict[new_key] = value # 更新度量字典
                        
                        # 我们不能在这里复制最后一个检查点
                        # 因为 save_checkpoint 使用线程。
                        # 因此此时文件可能是空的!
                        topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict) # 获取 topk 检查点路径
    
                        if topk_ckpt_path is not None: # 如果 topk 检查点路径存在
                            self.save_checkpoint(path=topk_ckpt_path) # 保存检查点到指定路径
                        cprint("checkpoint saved.", "green") # 打印检查点保存信息
                    # ========= 本周期评估结束 ==========
                    policy.train() # 设置策略为训练模式
    
                    # 周期结束
                    # 最后一步的日志记录与验证和 rollout 结合
    
                    # 记录日志到 WandB
                    wandb_run.log(step_log, step=self.global_step) 
    
                    # 记录日志到 JSON 日志记录器
                    json_logger.log(step_log) 
                    
                    # 全局步数加 1
                    self.global_step += 1 
                    # 周期数加 1
                    self.epoch += 1 
                    # 删除步骤日志
                    del step_log 
    
            # 停止 WandB 运行
            wandb_run.finish()
  5. get_model 方法用于加载最新的检查点,并返回训练好的模型
        def get_model(self):                  # 获取模型方法
            cfg = copy.deepcopy(self.cfg)     # 深拷贝配置
            
            # 设置标签为最新
            tag = "latest" 
            # tag = "best"
    
            # 获取最新的检查点路径
            lastest_ckpt_path = self.get_checkpoint_path(tag=tag) 
            
            # 如果检查点文件存在
            if lastest_ckpt_path.is_file(): 
                # 打印恢复信息
                cprint(f"Resuming from checkpoint {lastest_ckpt_path}", 'magenta') 
    
                # 加载检查点
                self.load_checkpoint(path=lastest_ckpt_path) 
    
            # 将检查点路径转换为字符串
            lastest_ckpt_path = str(lastest_ckpt_path) 
    如果启用了 EMA,则返回 EMA 模型
            policy = self.model              # 设置策略为模型
            if cfg.training.use_ema:         # 如果使用 EMA
                policy = self.ema_model      # 设置策略为 EMA 模型    
            policy.eval()         # 设置策略为评估模式
    
            return policy         # 返回策略
  6. 最后,通过 hydra.main 装饰器定义了 main 函数,实例化 iDP3Workspace 并调用 run 方法启动训练
    # Hydra 主函数装饰器
    @hydra.main( 
        # 配置路径
        config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")), 
        # 配置名称
        config_name=pathlib.Path(__file__).stem)     
    
     # 主函数
    def main(cfg):
        # 实例化工作空间
        workspace = iDP3Workspace(cfg)      
    
        # 运行工作空间         
        workspace.run() 
    
    # 如果当前模块是主模块
    if __name__ == "__main__":     
        # 调用主函数
        main() 

以上便是workspace/idp3_workspace.py的全部代码,值得再次强调的是

上面说,“ 注意在每个 epoch 结束时,进行模型验证——其中会调用Improved-3D-Diffusion-Policy/diffusion_policy_3d/policy/diffusion_pointcloud_policy.py中的predict_action,详见下文的「3.3 3D版本——policy/diffusion_pointcloud_policy.py:基于点云的扩散策略(与3.2节有相似)

那怎么判定其中的predict_action,是来自哪个类的呢?


要判定 predict_action 方法是来自哪个类的,可以通过以下步骤进行分析:

  1. 查找 predict_action 的调用位置:在代码中,predict_action 方法在 run 方法的验证部分被调用:
    result = policy.predict_action(obs_dict)
  2. 确定 policy 的类型:policy 变量在 run 方法中被赋值:
    policy = self.model
    if cfg.training.use_ema:
        policy = self.ema_model
  3. 确定 self.model 和 self.ema_model 的类型:
    在 __init__ 方法中,self.model 和 self.ema_model 被初始化为DiffusionPointcloudPolicy
    self.model: DiffusionPointcloudPolicy = hydra.utils.instantiate(cfg.policy)
    self.ema_model: DiffusionPointcloudPolicy = None
    if cfg.training.use_ema:
        try:
            self.ema_model = copy.deepcopy(self.model)
        except:
            self.ema_model = hydra.utils.instantiate(cfg.policy)
  4. 查找 DiffusionPointcloudPolicy 类:DiffusionPointcloudPolicy 类在以下位置被导入:
    from diffusion_policy_3d.policy.diffusion_pointcloud_policy import DiffusionPointcloudPolicy
  5. 查看 DiffusionPointcloudPolicy 类的定义:打开 diffusion_pointcloud_policy.py 文件,查找 DiffusionPointcloudPolicy 类的定义,并查看该类中是否有 predict_action 方法——当然,一查确实有,详见下文的3.3 3D版本——policy/diffusion_pointcloud_policy.py:基于点云的扩散策略(与3.2节有相似)

通过这些步骤,可以确定 predict_action 方法是 DiffusionPointcloudPolicy 类中的方法。具体来说,DiffusionPointcloudPolicy 类定义了 predict_action 方法,并且在 run 方法中通过 policy 变量调用了该方法

1.4 diffusion_policy_3d/dataset:各种数据集及相关处理

1.4.1 dataset/base_dataset.py:低维、图像、点云、通用等4类数据集

该代码文件定义了四个基类,分别用于处理低维数据集、图像数据集、点云数据集和通用数据集。这些基类继承自 torch.utils.data.Dataset,并定义了一些抽象方法和默认行为

下面逐一阐述这4个基类

处理低维数据集:BaseLowdimDataset类

class BaseLowdimDataset(torch.utils.data.Dataset): 
    def get_validation_dataset(self) -> 'BaseLowdimDataset':  
        # 默认返回一个空的数据集
        return BaseLowdimDataset()

    def get_normalizer(self, **kwargs) -> LinearNormalizer:  
        raise NotImplementedError()  # 抛出未实现的异常

    def get_all_actions(self) -> torch.Tensor: 
        raise NotImplementedError()  # 抛出未实现的异常

    # 定义 __len__ 方法,返回数据集的长度   
    def __len__(self) -> int:  
        return 0  # 默认返回 0
    
    # 定义 __getitem__ 方法,返回一个包含观察和动作的字典
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:  
        """
        output:
            obs: T, Do  # 观察数据的形状为 (T, Do)
            action: T, Da  # 动作数据的形状为 (T, Da)
        """
        raise NotImplementedError()  # 抛出未实现的异常

另外三个类的实现差不多,就不再一一贴它们的代码了

  • 处理图像数据集:BaseImageDataset 类
  • 处理点云数据集:BasePointcloudDataset 类
  • 通用的数据集基类:BaseDataset 类

这些基类为不同类型的数据集提供了统一的接口和默认行为,子类可以继承这些基类并实现具体的方法,以处理特定类型的数据集

1.4.2 dataset/gr1_dex_dataset_3d.py:处理 3D 数据集

GR1DexDataset3D 类继承自 BaseDataset,用于处理 3D 数据集

其构造函数接受多个参数

    def __init__(self,
            zarr_path,      # 数据集路径
            horizon=1,      # 时间跨度
            pad_before=0,      # 前填充
            pad_after=0,      # 后填充
            seed=42,            # 随机种子
            val_ratio=0.0,      # 验证集比例
            max_train_episodes=None,  # 最大训练集数量
            task_name=None,       # 任务名称
            num_points=4096,      # 点云数量
            ):

在初始化过程中,使用 cprint 打印加载数据集的信息,并设置类的属性

在构造函数__init__

  1. 首先调用父类的构造函数 super().__init__() 进行初始化。然后,使用 cprint 打印加载数据集的信息,并设置类的属性 task_name 和 num_points
            super().__init__()  # 调用父类的构造函数
            cprint(f'Loading GR1DexDataset from {zarr_path}', 'green')  # 打印加载数据集的信息
            self.task_name = task_name  # 设置任务名称
    
            self.num_points = num_points  # 设置点云数量
  2. 接下来,定义一个包含 `state` 和 `action` 的 buffer_keys 列表,并将 `point_cloud` 添加到该列表中
            buffer_keys = [  # 定义缓冲区键列表
                'state',  # 状态
                'action',  # 动作
            ]
            
            buffer_keys.append('point_cloud')  # 添加点云键
    通过调用 ReplayBuffer.copy_from_path 方法,从指定路径加载数据,并生成一个 ReplayBuffer 对象
            self.replay_buffer = ReplayBuffer.copy_from_path(  # 从指定路径加载重放缓冲区
                zarr_path, keys=buffer_keys)
  3. 接着,使用 get_val_mask 方法生成验证集掩码 val_mask,并通过取反操作生成训练集掩码 train_mask
    为了控制训练集的大小,使用 downsample_mask 方法对训练集掩码进行下采样
            val_mask = get_val_mask(  # 获取验证集掩码
                n_episodes=self.replay_buffer.n_episodes,  # 重放缓冲区中的集数
                val_ratio=val_ratio,      # 验证集比例
                seed=seed)      # 随机种子
            train_mask = ~val_mask      # 训练集掩码为验证集掩码的取反
            train_mask = downsample_mask(      # 对训练集掩码进行下采样
                mask=train_mask,      # 掩码
                max_n=max_train_episodes,      # 最大训练集数量
                seed=seed)      # 随机种子
  4. 最后,创建一个 SequenceSampler 对象 self.sampler,用于从重放缓冲区中采样数据
    SequenceSampler 对象的初始化参数包括重放缓冲区 replay_buffer、时间跨度 sequence_length、填充参数 pad_before 和 pad_after 以及训练集掩码 episode_mask
            self.sampler = SequenceSampler(  # 创建序列采样器
                replay_buffer=self.replay_buffer,  # 重放缓冲区
                sequence_length=horizon,  # 序列长度
                pad_before=pad_before,      # 前填充
                pad_after=pad_after,      # 后填充
                episode_mask=train_mask)  # 训练集掩码
    构造函数还设置了类的其他属性,包括 train_mask、horizon、pad_before 和 pad_after
            self.train_mask = train_mask  # 设置训练集掩码
            self.horizon = horizon  # 设置时间跨度
            self.pad_before = pad_before  # 设置前填充
            self.pad_after = pad_after  # 设置后填充
    通过这些步骤,构造函数完成了数据集对象的初始化,为后续的数据处理和模型训练提供了基础

接下来,get_validation_dataset 方法用于生成验证数据集

它通过浅拷贝当前对象,并创建一个新的 SequenceSampler 对象,使用验证集掩码来替换训练集掩码

其次,get_normalizer 方法用于生成数据归一化器

它首先从重放缓冲区中提取 `action` 数据,并使用 LinearNormalizer 对其进行拟合。然后,为 point_cloud 和 agent_pos 创建身份归一化器,并返回归一化器对象

而剩下的方法有

  • __len__ 方法返回数据集的长度,即采样器的长度
  • _sample_to_data 方法将采样的数据转换为所需的格式
    包括将状态和点云数据转换为浮点数,并对点云数据进行均匀采样
  • __getitem__ 方法根据索引从采样器中获取数据样本,并将其转换为 PyTorch 张量
    通过 dict_apply 方法,将数据字典中的所有 NumPy 数组转换为 PyTorch 张量,并返回转换后的数据

1.4.3 dataset/gr1_dex_dataset_image.py:处理图像和深度信息

GR1DexDatasetImage 类继承自 BaseDataset,用于处理包含图像和深度信息的数据集

其构造函数__init__接受多个参数

    def __init__(self,
            zarr_path,      # 数据集路径
            horizon=1,      # 时间跨度
            pad_before=0,      # 前填充
            pad_after=0,       # 后填充
            seed=42,          # 随机种子
            val_ratio=0.0,    # 验证集比例
            max_train_episodes=None,  # 最大训练集数量
            task_name=None,      # 任务名称
            use_img=True,        # 是否使用图像
            use_depth=False,     # 是否使用深度信息
            ):

在初始化过程中,使用 cprint 打印加载数据集的信息,并设置类的属性

  1. 该类首先定义了一个包含 `state` 和 `action` 的 buffer_keys 列表
            self.task_name = task_name      # 设置任务名称
            self.use_img = use_img          # 设置是否使用图像
            self.use_depth = use_depth      # 设置是否使用深度信息
    
            buffer_keys = [      # 定义缓冲区键列表
                'state',         # 状态
                'action',        # 动作
    如果 use_img 为真,则将 `img` 添加到 buffer_keys 列表中;如果 use_depth 为真,则将 depth 添加到 buffer_keys 列表中
            if self.use_img:  # 如果使用图像
                buffer_keys.append('img')  # 添加图像键
            if self.use_depth:  # 如果使用深度信息
                buffer_keys.append('depth')  # 添加深度键
    然后,通过调用 ReplayBuffer.copy_from_path 方法从指定路径加载数据,并生成一个 ReplayBuffer 对象
            self.replay_buffer = ReplayBuffer.copy_from_path(  # 从指定路径加载重放缓冲区
                zarr_path, keys=buffer_keys)
  2. 接着,使用 get_val_mask 方法生成验证集掩码 val_mask,并通过取反操作生成训练集掩码 train_mask
    为了控制训练集的大小,使用 downsample_mask 方法对训练集掩码进行下采样
            val_mask = get_val_mask(          # 获取验证集掩码
                n_episodes=self.replay_buffer.n_episodes,  # 重放缓冲区中的集数
                val_ratio=val_ratio,          # 验证集比例
                seed=seed)  # 随机种子
            train_mask = ~val_mask          # 训练集掩码为验证集掩码的取反
            train_mask = downsample_mask(   # 对训练集掩码进行下采样
                mask=train_mask,               # 掩码
                max_n=max_train_episodes,      # 最大训练集数量
                seed=seed)                  # 随机种子
  3. 最后,创建一个 SequenceSampler 对象 self.sampler,用于从重放缓冲区中采样数据
            self.sampler = SequenceSampler(        # 创建序列采样器
                replay_buffer=self.replay_buffer,  # 重放缓冲区
                sequence_length=horizon,           # 序列长度
                pad_before=pad_before,             # 前填充
                pad_after=pad_after,          # 后填充
                episode_mask=train_mask)      # 训练集掩码
            self.train_mask = train_mask      # 设置训练集掩码
            self.horizon = horizon            # 设置时间跨度
            self.pad_before = pad_before      # 设置前填充
            self.pad_after = pad_after        # 设置后填充

接下来,get_validation_dataset 方法用于生成验证数据集

它通过浅拷贝当前对象,并创建一个新的 SequenceSampler 对象,使用验证集掩码来替换训练集掩码

get_normalizer 方法用于生成数据归一化器

  1. 它首先从重放缓冲区中提取 `action` 数据,并使用 LinearNormalizer 对其进行拟合
    如果 use_img 为真,则为 image 创建身份归一化器;如果 use_depth 为真,则为 depth 创建身份归一化器
  2. 最后,为 agent_pos 创建身份归一化器,并返回归一化器对象

至于剩下的方法和上节的gr1_dex_dataset_3d.py一样

  • __len__ 方法返回数据集的长度,即采样器的长度
  • _sample_to_data 方法将采样的数据转换为所需的格式,包括将状态数据转换为浮点数,并根据需要处理图像和深度数据
  • __getitem__ 方法根据索引从采样器中获取数据样本,并将其转换为 PyTorch 张量。通过 dict_apply 方法,将数据字典中的所有 NumPy 数组转换为 PyTorch 张量,并返回转换后的数据

第二部分 扩散模型与3D点云编码器的实现:diffusion_policy_3d/model

2.1 model/common

2.2 model/diffusion:涉及五个代码文件

2.2.1 diffusion/conditional_unet1d.py:再涉及3个代码子文件,分别实现交叉注意力、条件残差块、条件U-Net 网络

2.2.1.1 CrossAttention 类:实现交叉注意力

CrossAttention 类是一个用于实现交叉注意力机制的 PyTorch 模块

它在初始化时接受三个参数:输入维度 in_dim、条件维度 cond_dim 和输出维度 out_dim

    def __init__(self, in_dim, cond_dim, out_dim):
  • 在 __init__ 方法中,定义了三个线性投影层 query_proj、key_proj 和 value_proj,分别用于将输入 x 和条件 cond 投影到查询、键和值
            super().__init__()
            self.query_proj = nn.Linear(in_dim, out_dim)
            self.key_proj = nn.Linear(cond_dim, out_dim)
            self.value_proj = nn.Linear(cond_dim, out_dim)
  • 在 forward 方法中
    首先将输入 x 和条件 cond 投影到查询、键和值
        def forward(self, x, cond):
            # x: [batch_size, t_act, in_dim]
            # cond: [batch_size, t_obs, cond_dim]
    
            # Project x and cond to query, key, and value
            query = self.query_proj(x)  # [batch_size, horizon, out_dim]
            key = self.key_proj(cond)   # [batch_size, horizon, out_dim]
            value = self.value_proj(cond)  # [batch_size, horizon, out_dim]
    然后计算注意力权重,并通过软最大化函数进行归一化
            # Compute attention
            attn_weights = torch.matmul(query, key.transpose(-2, -1))  # [batch_size, horizon, horizon]
            attn_weights = F.softmax(attn_weights, dim=-1)
    最后,应用注意力权重到值上,得到注意力输出
            # Apply attention
            attn_output = torch.matmul(attn_weights, value)  # [batch_size, horizon, out_dim]
            
            return attn_output
2.2.1.2 ConditionalResidualBlock1D 类:条件残差块,在一维卷积网络中实现条件处理

ConditionalResidualBlock1D 类是一个条件残差块,用于在一维卷积网络中实现条件处理

它在初始化时接受多个参数,如下所示

    def __init__(self,          # 定义构造函数
                 in_channels,   # 输入通道数
                 out_channels,  # 输出通道数
                 cond_dim,      # 条件维度
                 kernel_size=3,       # 卷积核大小,默认值为3
                 n_groups=8,          # 组归一化的组数,默认值为8
                 condition_type='film'):      # 条件类型,默认值为'film'

在初始化过程中,定义了两个一维卷积块,并根据条件类型初始化条件编码器 cond_encoder

在构造函数中

  1. 首先创建了一个包含两个 Conv1dBlock 的 nn.ModuleList,每个 Conv1dBlock 包含一维卷积、组归一化和 Mish 激活函数
            self.blocks = nn.ModuleList([      # 定义一个包含两个卷积块的模块列表
                Conv1dBlock(in_channels,      # 第一个卷积块,输入通道数为 in_channels
                            out_channels,      # 输出通道数为 out_channels
                            kernel_size,       # 卷积核大小
                            n_groups=n_groups),      # 组归一化的组数
                Conv1dBlock(out_channels,      # 第二个卷积块,输入通道数为 out_channels
                            out_channels,      # 输出通道数为 out_channels
                            kernel_size,       # 卷积核大小
                            n_groups=n_groups),      # 组归一化的组数
            ])
  2. 接着,根据条件类型 condition_type 初始化条件编码器 cond_encoder
            self.condition_type = condition_type  # 设置条件类型
    
            cond_channels = out_channels  # 条件通道数初始为输出通道数
    如果条件类型为 `film`,则创建一个 nn.Sequential,包含 Mish 激活函数、线性层和 Rearrange 操作,用于预测每个通道的缩放和偏移
            if condition_type == 'film':      # 如果条件类型为 'film'
                # 预测每个通道的缩放和偏移
                cond_channels = out_channels * 2      # 条件通道数为输出通道数的两倍
                self.cond_encoder = nn.Sequential(    # 定义条件编码器
                    nn.Mish(),  # Mish 激活函数
                    nn.Linear(cond_dim, cond_channels),      # 线性层
                    Rearrange('batch t -> batch t 1'),       # 重新排列张量维度
                )
    如果条件类型为 `add`,则创建一个包含 Mish 激活函数、线性层和 Rearrange 操作的 nn.Sequential
            elif condition_type == 'add':  # 如果条件类型为 'add'
                self.cond_encoder = nn.Sequential(  # 定义条件编码器
                    nn.Mish(),  # Mish 激活函数
                    nn.Linear(cond_dim, out_channels),  # 线性层
                    Rearrange('batch t -> batch t 1'),  # 重新排列张量维度
                )
    如果条件类型为 `cross_attention_add` 或 `cross_attention_film`,则使用 CrossAttention类进行交叉注意力计算
            elif condition_type == 'cross_attention_add':  # 如果条件类型为 'cross_attention_add'
                self.cond_encoder = CrossAttention(in_channels, cond_dim, out_channels)  # 定义交叉注意力编码器
            elif condition_type == 'cross_attention_film':  # 如果条件类型为 'cross_attention_film'
                cond_channels = out_channels * 2  # 条件通道数为输出通道数的两倍
                self.cond_encoder = CrossAttention(in_channels, cond_dim, cond_channels)  # 定义交叉注意力编码器
    如果条件类型为 `mlp_film`,则创建一个包含两个 Mish 激活函数和两个线性层的 nn.Sequential
            elif condition_type == 'mlp_film':      # 如果条件类型为 'mlp_film'
                cond_channels = out_channels * 2    # 条件通道数为输出通道数的两倍
                self.cond_encoder = nn.Sequential(  # 定义条件编码器
                    nn.Mish(),  # Mish 激活函数
                    nn.Linear(cond_dim, cond_dim),  # 线性层
                    nn.Mish(),  # Mish 激活函数
                    nn.Linear(cond_dim, cond_channels),  # 线性层
                    Rearrange('batch t -> batch t 1'),   # 重新排列张量维度
                )
    如果条件类型未实现,则抛出 NotImplementedError 异常
            else:  # 如果条件类型未实现
                raise NotImplementedError(f"condition_type {condition_type} not implemented")  # 抛出未实现的异常

在上述初始化的基础上,forward 方法中

  1. 首先通过第一个卷积块处理输入 x
  2. 如果提供了条件 cond,则根据条件类型对输出进行调整

    如果条件类型为 `film`,则通过条件编码器生成缩放和偏移,并应用于输出
    如果条件类型为 `add`,则将条件编码器的输出与当前输出相加
    如果条件类型为 `cross_attention_add` 或 `cross_attention_film`,则通过交叉注意力计算生成嵌入,并应用于输出
    如果条件类型为 `mlp_film`,则通过条件编码器生成缩放和偏移,并应用于输出
  3. 最后,通过第二个卷积块处理输出,并将其与残差连接相加,返回最终输出
2.2.1.3 ConditionalUnet1D:条件一维 U-Net 网络,在一维数据上实现条件生成任务

ConditionalUnet1D 类是一个条件一维 U-Net 网络,用于在一维数据上实现条件生成任务

它在初始化时接受多个参数,如下所示

def __init__(self,      # 定义构造函数
        input_dim,      # 输入维度
        local_cond_dim=None,       # 局部条件维度,默认值为 None
        global_cond_dim=None,      # 全局条件维度,默认值为 None
        diffusion_step_embed_dim=256,      # 扩散步嵌入维度,默认值为 256
        down_dims=[256, 512, 1024],        # 下采样维度列表,默认值为 [256, 512, 1024]
        kernel_size=3,       # 卷积核大小,默认值为 3
        n_groups=8,          # 组归一化的组数,默认值为 8
        condition_type='film',      # 条件类型,默认值为 'film'
        use_down_condition=True,    # 是否使用下采样条件,默认值为 True
        use_mid_condition=True,     # 是否使用中间条件,默认值为 True
        use_up_condition=True,      # 是否使用上采样条件,默认值为 True
        ):
  • 在 __init__ 方法中,定义了扩散步编码器、局部条件编码器、中间模块、下采样模块和上采样模块,并初始化最终的卷积层
  • 在 forward 方法中,首先对时间步进行编码,然后根据条件类型对局部和全局条件进行处理,最后通过下采样、中间处理和上采样阶段生成最终输出

2.2.2 diffusion/conv1d_components.py:涉及一维卷积、下采样、上采样

该代码定义了几个用于一维卷积操作的 PyTorch 模块,包括 Downsample1d、Upsample1d 和 Conv1dBlock

  1. Downsample1d 类是一个用于一维下采样的模块。它在初始化时接受一个参数 dim,并定义了一个一维卷积层 self.conv,该卷积层的卷积核大小为 3,步幅为 2,填充为 1
    class Downsample1d(nn.Module):  
        def __init__(self, dim):      # 定义构造函数,接受一个参数 dim
            super().__init__()        # 调用父类的构造函数
            self.conv = nn.Conv1d(dim, dim, 3, 2, 1)      # 定义一个一维卷积层,卷积核大小为 3,步幅为 2,填充为 1
    在 forward 方法中,输入 x 通过卷积层进行下采样
        def forward(self, x):      # 定义前向传播函数
            return self.conv(x)    # 返回卷积后的结果
  2. Upsample1d 类是一个用于一维上采样的模块。它在初始化时同样接受一个参数 dim,并定义了一个一维反卷积层 self.conv,该反卷积层的卷积核大小为 4,步幅为 2,填充为 1
    class Upsample1d(nn.Module): 
        def __init__(self, dim):      # 定义构造函数,接受一个参数 dim
            super().__init__()        # 调用父类的构造函数
            self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)      # 定义一个一维反卷积层,卷积核大小为 4,步幅为 2,填充为 1
    在 forward 方法中,输入 x 通过反卷积层进行上采样
        def forward(self, x):      # 定义前向传播函数
            return self.conv(x)    # 返回反卷积后的结果
  3. Conv1dBlock 类是一个包含一维卷积、组归一化和 Mish 激活函数的模块
    它在初始化时接受多个参数,如下所示
    class Conv1dBlock(nn.Module): 
        '''
            Conv1d --> GroupNorm --> Mish  # 一维卷积 --> 组归一化 --> Mish 激活函数
        '''
    
        # 定义构造函数,接受输入通道数、输出通道数、卷积核大小、组数
        def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):  
            super().__init__() 
    在 __init__ 方法中,定义了一个顺序容器 self.block,其中包含一维卷积层、组归一化层和 Mish 激活函数
             # 定义一个顺序容器
             self.block = nn.Sequential(  
                # 一维卷积层,填充为卷积核大小的一半
                nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),  
    
                # 重新排列张量维度(已注释)
                # Rearrange('batch channels horizon -> batch channels 1 horizon'),  
    
                # 组归一化层
                nn.GroupNorm(n_groups, out_channels),  
    
                # 重新排列张量维度(已注释)
                # Rearrange('batch channels 1 horizon -> batch channels horizon'),  
    
                 # Mish 激活函数
                nn.Mish(), 
            )
    在 forward 方法中,输入 x 通过顺序容器中的各层进行处理
        def forward(self, x):      # 定义前向传播函数
            return self.block(x)   # 返回顺序容器处理后的结果
  4. 最后,定义了一个 test 函数,用于测试 Conv1dBlock 模块
    该函数创建了一个 Conv1dBlock 实例 cb,并生成一个形状为 `(1, 256, 16)` 的全零张量 x
    然后,将 x 传递给 cb 进行处理,并将输出存储在变量 o 中
    # 定义测试函数
    def test():  
         # 创建一个 Conv1dBlock 实例
        cb = Conv1dBlock(256, 128, kernel_size=3)     
    
        # 创建一个全零张量,形状为 (1, 256, 16)
        x = torch.zeros((1, 256, 16))                  
        # 将张量传入 Conv1dBlock 实例,并获取输出
        o = cb(x)  

2.2.3 diffusion/ema_model.py:实现模型权重的指数移动平均EMA

该代码定义了一个名为 EMAModel 的类,用于实现模型权重的指数移动平均(EMA)。EMA 是一种常用的技术,通过对模型权重进行平滑处理,可以提高模型的稳定性和泛化能力

  1. 在 EMAModel类的初始化方法 __init__ 中,接受多个参数,如下所示
    class EMAModel:          # 定义 EMAModel 类
        """
        模型权重的指数移动平均
        """
    
        # 定义构造函数
        def __init__(  
            self,
            model,                    # 模型
            update_after_step=0,      # 在多少步之后开始更新 EMA 的步数 update_after_step
            inv_gamma=1.0,            # EMA 预热的逆乘法因子,默认值为 1.0
            power=2 / 3,              # EMA 预热的指数因子,默认值为 2/3
            min_value=0.0,            # EMA 的最小衰减率,默认值为 0.0
            max_value=0.9999          # EMA 的最大衰减率,默认值为 0.9999
        ):
    初始化过程中,将传入的模型设置为评估模式,并禁用其梯度计算。还初始化了一些其他属性,如 EMA 衰减率 decay 和优化步数 optimization_step
            """
            @crowsonkb 关于 EMA 预热的笔记:
                如果 gamma=1 且 power=1,则实现简单平均。gamma=1,power=2/3 是适合训练一百万步或更多步的模型的好值
                (在 31.6K 步时达到衰减因子 0.999,在 1M 步时达到 0.9999),
                gamma=1,power=3/4 适合训练较少步数的模型(在 10K 步时达到衰减因子 0.999,在 215.4K 步时达到 0.9999)。
            参数:
                inv_gamma (float): EMA 预热的逆乘法因子。默认值: 1。
                power (float): EMA 预热的指数因子。默认值: 2/3。
                min_value (float): EMA 的最小衰减率。默认值: 0。
            """
    
            self.averaged_model = model      # 设置平均模型
            self.averaged_model.eval()       # 将平均模型设置为评估模式
            self.averaged_model.requires_grad_(False)       # 禁用平均模型的梯度计算
    
            self.update_after_step = update_after_step      # 设置在多少步之后开始更新 EMA
            self.inv_gamma = inv_gamma       # 设置 EMA 预热的逆乘法因子
            self.power = power               # 设置 EMA 预热的指数因子
            self.min_value = min_value       # 设置 EMA 的最小衰减率
            self.max_value = max_value       # 设置 EMA 的最大衰减率
    
            self.decay = 0.0                 # 初始化衰减率
            self.optimization_step = 0       # 初始化优化步数
  2. get_decay 方法用于计算 EMA 的衰减因子。它根据当前的优化步数计算衰减因子,并确保其在 min_value 和 max_value 之间。如果当前步数小于等于 0,则返回 0.0
  3. step 方法用于更新 EMA 模型的权重。该方法使用 torch.no_grad() 装饰器,以确保在更新权重时不会计算梯度
    \rightarrow  首先,计算当前步数的衰减因子
        @torch.no_grad()      # 使用 torch.no_grad() 装饰器,禁用梯度计算
        def step(self, new_model):      # 定义更新 EMA 模型的方法
            self.decay = self.get_decay(self.optimization_step)      # 获取当前步数的衰减因子
    
    \rightarrow  然后,遍历新模型和 EMA 模型的所有模块和参数,并根据参数类型和条件更新 EMA 模型的权重
    如果参数是批归一化层的参数或不需要梯度计算的参数,则直接复制新模型的参数值
            all_dataptrs = set()      # 初始化数据指针集合
    
            # 遍历新模型和平均模型的所有模块
            for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):   
    
                # 遍历模块的所有参数
                for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)):  
                    # 仅迭代直接参数
                    if isinstance(param, dict):  # 如果参数是字典
                        raise RuntimeError('Dict parameter not supported')  # 抛出运行时异常
    
                    if isinstance(module, _BatchNorm):  # 如果模块是批归一化层
                        # 跳过批归一化层
                        ema_param.copy_(param.to(dtype=ema_param.dtype).data)  # 复制参数数据
    
                    # 如果参数不需要梯度计算
                    elif not param.requires_grad:  
                        ema_param.copy_(param.to(dtype=ema_param.dtype).data)  # 复制参数数据
    否则,使用 EMA 衰减因子对参数进行加权更新
                    else: 
                         # 乘以衰减因子
                        ema_param.mul_(self.decay)     
                        # 加上参数数据乘以 (1 - 衰减因子)
                        ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)      
    \rightarrow  最后,增加优化步数
            # 验证遍历模块然后参数与递归遍历参数是否相同
            self.optimization_step += 1      # 增加优化步数

通过这种方式,EMAModel 类可以在训练过程中平滑地更新模型权重,从而提高模型的稳定性和性能

2.2.4 diffusion/mask_generator.py

该代码片段定义了几个用于生成掩码的函数和类,这些掩码生成器类通过不同的配置和条件,生成适用于各种深度学习任务的掩码,方便模型处理不同的输入维度和条件

// 待更

2.2.5 diffusion/positional_embedding.py:为输入数据添加位置信息

SinusoidalPosEmb 类是一个用于生成正弦位置嵌入的 PyTorch 模块,用于为输入数据添加位置信息

首先,在 __init__ 方法中,接受一个参数 dim,表示嵌入的维度。调用 super().__init__() 初始化父类 nn.Module,并将 dim 存储为实例属性

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

其次,forward 方法用于计算位置嵌入

  1. 首先,获取输入张量 x 的设备信息 device
        def forward(self, x):
            device = x.device
  2. 然后,计算嵌入维度的一半 half_dim
            half_dim = self.dim // 2
  3. 接下来,计算一个常数 emb,该常数用于缩放位置索引
            emb = math.log(10000) / (half_dim - 1)
    对应的公式为

    然后使用 torch.arange 生成一个从 0 到 half_dim 的张量,并将其乘以 `-emb`,然后通过 torch.exp 计算指数
            emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
    其对应的公式为

    简化下是
  4. 接着,将输入张量 x 扩展维度并与生成的指数张量相乘
            emb = x[:, None] * emb[None, :]
    对应公式为
  5. 最后,通过 torch.cat 将正弦和余弦嵌入拼接在一起,得到最终的嵌入张量
            emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
            return emb
    最终对应的公式为
    \begin{array}{c} \mathrm{PE}(p o s, 2 i)=\sin \left(p o s \cdot 10000^{-2 i / \operatorname{dim}}\right) \\ \mathrm{PE}(p o s, 2 i+1)=\cos \left(p o s \cdot 10000^{-2 i / \operatorname{dim}}\right) \end{array}

注意,transformer原始论文中对位置编码的公式为

\begin{aligned} P E_{(p o s, 2 i)} & =\sin \left(p o s / 10000^{2 i / d_{\text {model }}}\right) \\ P E_{(p o s, 2 i+1)} & =\cos \left(p o s / 10000^{2 i / d_{\text {model }}}\right) \end{aligned}


如不太理解,详见此文《一文通透位置编码:从标准位置编码、旋转位置编码RoPE到ALiBi、LLaMA 2 Long(含NTK-aware简介)

2.3 model/vision

2.4 model/vision_3d

2.4.1 vision_3d/multi_stage_pointnet.py:对点云数据进行编码

该代码定义了一个名为 MultiStagePointNetEncoder 的 PyTorch 模块,用于对点云数据进行编码。该模块包含两个辅助函数 meanpool 和 maxpool,以及一个主要的编码器类 MultiStagePointNetEncoder

  1. meanpool 函数用于在指定维度上对输入张量 x 进行平均池化操作
    def meanpool(x, dim=-1, keepdim=False):
        out = x.mean(dim=dim, keepdim=keepdim)
        return out
  2. maxpool 函数用于在指定维度上对输入张量 x 进行最大池化操作
    def maxpool(x, dim=-1, keepdim=False):
        out = x.max(dim=dim, keepdim=keepdim).values
        return out
  3. MultiStagePointNetEncoder 类继承自 nn.Module,用于实现多阶段的点云编码器。其构造函数接受多个参数,如下所示
    class MultiStagePointNetEncoder(nn.Module):  
    
         # 定义构造函数,接受隐藏维度、输出通道数、层数和其他参数
        def __init__(self, h_dim=128, out_channels=128, num_layers=4, **kwargs): 
            super().__init__() 
    在初始化过程中,__init__定义了激活函数 LeakyReLU、输入卷积层 conv_in、多个隐藏层 layers 和全局层 global_layers,以及输出卷积层 conv_out
    
            self.h_dim = h_dim  # 设置隐藏维度
            self.out_channels = out_channels  # 设置输出通道数
            self.num_layers = num_layers  # 设置层数
    
            # 定义 LeakyReLU 激活函数
            self.act = nn.LeakyReLU(negative_slope=0.0, inplace=False)  
    
             # 定义输入卷积层,输入通道数为 3,输出通道数为 h_dim,卷积核大小为 1
            self.conv_in = nn.Conv1d(3, h_dim, kernel_size=1) 
    
            # 定义两个模块列表,分别用于存储局部卷积层和全局卷积层
            self.layers, self.global_layers = nn.ModuleList(), nn.ModuleList()  
    
            # 遍历层数
            for i in range(self.num_layers): 
                # 添加局部卷积层,输入和输出通道数均为 h_dim,卷积核大小为 1
                self.layers.append(nn.Conv1d(h_dim, h_dim, kernel_size=1))  
    
                # 添加全局卷积层,输入通道数为 h_dim * 2,输出通道数为 h_dim,卷积核大小为 1
                self.global_layers.append(nn.Conv1d(h_dim * 2, h_dim, kernel_size=1))  
    
            # 定义输出卷积层,输入通道数为 h_dim * 层数,输出通道数为 out_channels,卷积核大小为 1
            self.conv_out = nn.Conv1d(h_dim * self.num_layers, out_channels, kernel_size=1)  
    在 forward 方法中
    首先,将输入张量 x 的维度进行转换
    然后,通过输入卷积层和激活函数进行初步处理
    接着,遍历每一层,对输入进行卷积和激活处理,并计算全局特征,将其与当前特征拼接。将所有层的特征拼接后,通过输出卷积层进行处理
    最后,在指定维度上进行最大池化,得到全局特征 x_global 并返回

该编码器模块通过多层卷积和全局特征提取,能够有效地对点云数据进行编码,提取出有用的全局特征。

2.4.2 vision_3d/point_process.py:针对点云的打乱/填充/采样操作(含NumPy和PyTorch实现)

该代码提供了一些用于点云处理的 PyTorch 和 NumPy 实现——点云处理在计算机视觉和3D建模中非常重要,特别是在处理和分析3D数据时

  1. 首先,导入了必要的库 torch 和 numpy
    然后,定义了一个 __all__ 列表,指定了该模块中可以被外部导入的函数,包括 shuffle_point_torch、pad_point_torch 和 uniform_sampling_torch
  2. 对点云数据进行随机打乱:shuffle_point_numpy。它接受一个形状为 `(B, N, C)` 的点云张量,其中 B 是批量大小,N 是点的数量,C 是每个点的特征维度。函数通过 np.random.permutation 生成一个随机排列的索引,并返回打乱后的点云
  3. 对点云数据进行填充:pad_point_numpy
    如果点的数量少于指定的 num_points,则用零点进行填充。填充后,调用 shuffle_point_numpy 函数对点云进行随机打乱
  4. 对点云数据进行均匀采样:uniform_sampling_numpy
    如果点的数量少于指定的 num_points,则调用 pad_point_numpy 进行填充。否则,通过 np.random.permutation 生成随机索引,并返回采样后的点云
  5. 打乱之shuffle_point_torch 函数是 shuffle_point_numpy 的 PyTorch 实现
    它使用 torch.randperm 生成随机排列的索引,并返回打乱后的点云
  6. 填充之pad_point_torch 函数是 pad_point_numpy 的 PyTorch 实现
    它首先检查点的数量是否少于指定的 num_points,如果是,则用零点进行填充。填充后,调用 shuffle_point_torch 函数对点云进行随机打乱
  7. 采样之uniform_sampling_torch 函数是 uniform_sampling_numpy 的 PyTorch 实现
    如果点的数量等于指定的 num_points,则直接返回点云。如果点的数量少于指定的 num_points,则调用 pad_point_torch 进行填充。否则,通过 torch.randperm 生成随机索引,并返回采样后的点云

这些函数为点云数据的处理提供了基础操作,包括随机打乱、填充和均匀采样,适用于不同的框架——NumPy 和 PyTorch

2.4.3 vision_3d/pointnet_extractor.py:包含点云编码器iDP3Encoder的实现

该代码片段定义了一个用于创建多层感知机(MLP)的函数 create_mlp,以及两个编码器类 StateEncoder 和 iDP3Encoder,用于处理状态和点云数据

首先,create_mlp 函数用于创建一个多层感知机(MLP),即一系列全连接层,每个全连接层后面跟随一个激活函数

  1. 函数接受五个参数:如下所示
    def create_mlp( 
            input_dim: int,       # 输入维度
            output_dim: int,      # 输出维度
            net_arch: List[int],  # 神经网络的架构,表示每层的单元数
            activation_fn: Type[nn.Module] = nn.ReLU,      # 每层之后使用的激活函数,默认值为 nn.ReLU
            squash_output: bool = False,       # 是否使用 Tanh 激活函数压缩输出,默认值为 False
    ) -> List[nn.Module]:                      # 返回值为 nn.Module 的列表
  2. 函数首先根据 net_arch 创建第一层全连接层和激活函数
        if len(net_arch) > 0:
            modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()]
        else:
            modules = []
    然后遍历 net_arch 创建中间层
        for idx in range(len(net_arch) - 1):
            modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
            modules.append(activation_fn())
    最后添加输出层和可选的 Tanh 激活函数
        if output_dim > 0:
            last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
            modules.append(nn.Linear(last_layer_dim, output_dim))
        if squash_output:
            modules.append(nn.Tanh())
    返回值是一个包含所有层的模块列表
        return modules

其次,StateEncoder 类继承自 nn.Module,用于对状态数据进行编码。其构造函数接受三个参数:如下所示

class StateEncoder(nn.Module):  
    def __init__(self, 
                 observation_space: Dict,           # 观察空间的字典
                 state_mlp_size=(64, 64),           # 状态 MLP 的大小,默认值为 (64, 64)
                 state_mlp_activation_fn=nn.ReLU):  # 状态 MLP 的激活函数,默认值为 nn.ReLU
        super().__init__()  
  • 在初始化过程中,首先获取状态的形状,并根据 state_mlp_size 创建 MLP
            self.state_key = 'full_state'     # 设置状态键
            self.state_shape = observation_space[self.state_key]      # 获取状态的形状
            cprint(f"[StateEncoder] state shape: {self.state_shape}", "yellow")  # 打印状态形状
            
            if len(state_mlp_size) == 0:          # 如果状态 MLP 的大小为空
                raise RuntimeError(f"State mlp size is empty")  # 抛出运行时异常
            elif len(state_mlp_size) == 1:        # 如果状态 MLP 的大小为 1
                net_arch = [] 
            else:  
                net_arch = state_mlp_size[:-1]    # 网络架构为状态 MLP 大小的前 n-1 个元素
            output_dim = state_mlp_size[-1]       # 输出维度为状态 MLP 大小的最后一个元素
    
            self.state_mlp = nn.Sequential(*create_mlp(self.state_shape[0], output_dim, net_arch, state_mlp_activation_fn))       # 创建状态 MLP
    
            cprint(f"[StateEncoder] output dim: {output_dim}", "red")  # 打印输出维度
            self.output_dim = output_dim      # 设置输出维度
  • forward 方法接受一个包含状态数据的字典 observations,并通过 MLP 对状态进行编码,返回编码后的特征

最后,iDP3Encoder 类同样继承自 nn.Module,用于对点云数据和状态数据进行联合编码

其构造函数接受多个参数,包括

    def __init__(self, 
                 observation_space: Dict,      # 观察空间的字典
                 state_mlp_size=(64, 64),      # 状态 MLP 的大小
                 state_mlp_activation_fn=nn.ReLU,  # 状态 MLP 的激活函数
                 pointcloud_encoder_cfg=None,      # 点云编码器的配置
                 use_pc_color=False,               # 是否使用点云颜色
                 pointnet_type='dp3_encoder',      # 点网类型
                 point_downsample=True,            # 是否对点云进行下采样
                 ):
  1. 在初始化过程中,设置了状态和点云的键值,并根据配置初始化点云预处理方法和点网编码器
            super().__init__()      # 调用父类的构造函数
            self.state_key = 'agent_pos'          # 状态键
            self.point_cloud_key = 'point_cloud'  # 点云键
            self.n_output_channels = pointcloud_encoder_cfg.out_channels  # 输出通道数
    在构造函数中,首先获取点云和状态的形状,并根据配置选择点云预处理方法
            self.point_cloud_shape = observation_space[self.point_cloud_key]  # 获取点云的形状
            self.state_shape = observation_space[self.state_key]    # 获取状态的形状
    
            self.num_points = pointcloud_encoder_cfg.num_points     # 点的数量,默认为 4096
    如果 pointnet_type 为 "multi_stage_pointnet",则导入并实例化 MultiStagePointNetEncoder 作为点云特征提取器
    否则,抛出 NotImplementedError 异常
            self.downsample = point_downsample      # 是否对点云进行下采样
            if self.downsample:  # 如果需要下采样
                self.point_preprocess = point_process.uniform_sampling_torch  # 使用均匀采样
            else:      # 否则
                self.point_preprocess = nn.Identity()        # 使用 Identity 层
            
            if pointnet_type == "multi_stage_pointnet":      # 如果点网类型为 "multi_stage_pointnet"
                from .multi_stage_pointnet import MultiStagePointNetEncoder      # 导入 MultiStagePointNetEncoder
                self.extractor = MultiStagePointNetEncoder(out_channels=pointcloud_encoder_cfg.out_channels)  # 实例化点云特征提取器
            else:      # 否则
                raise NotImplementedError(f"pointnet_type: {pointnet_type}")  # 抛出未实现的异常
    接着,根据 state_mlp_size 创建状态 MLP,并计算输出通道数
            if len(state_mlp_size) == 0:      # 如果状态 MLP 的大小为空
                raise RuntimeError(f"State mlp size is empty")  # 抛出运行时异常
            elif len(state_mlp_size) == 1:    # 如果状态 MLP 的大小为 1
                net_arch = []      # 网络架构为空
            else:      # 否则
                net_arch = state_mlp_size[:-1]      # 网络架构为状态 MLP 大小的前 n-1 个元素
            output_dim = state_mlp_size[-1]         # 输出维度为状态 MLP 大小的最后一个元素
    
            self.n_output_channels  += output_dim  # 输出通道数加上输出维度
            self.state_mlp = nn.Sequential(*create_mlp(self.state_shape[0], output_dim, net_arch, state_mlp_activation_fn))      # 创建状态 MLP
    
            cprint(f"[DP3Encoder] output dim: {self.n_output_channels}", "red")  # 打印输出通道数
  2. forward 方法用于根据输入的观察字典 observations 生成编码特征
    首先,获取点云数据并检查其形状是否为三维。如果需要下采样,则对点云数据进行预处理
        def forward(self, observations: Dict) -> torch.Tensor:  # 定义前向传播函数
            points = observations[self.point_cloud_key]  # 获取点云数据
            assert len(points.shape) == 3, cprint(f"point cloud shape: {points.shape}, length should be 3", "red")      # 确保点云数据的形状为三维
    
            if self.downsample:      # 如果需要下采样
                points = self.point_preprocess(points, self.num_points)  # 对点云数据进行预处理
    然后,通过点云特征提取器提取点云特征
            pn_feat = self.extractor(points)  # 提取点云特征
    接着,获取状态数据并通过状态 MLP 进行编码
            state = observations[self.state_key]    # 获取状态数据
            state_feat = self.state_mlp(state)      # 对状态数据进行编码
    最后,将点云特征和状态特征拼接在一起,返回最终的编码特征
            final_feat = torch.cat([pn_feat, state_feat], dim=-1)      # 拼接点云特征和状态特征
            return final_feat          # 返回最终的编码特征

output_shape 方法返回编码器的输出通道数。

总的来说,iDP3Encoder 类通过点云特征提取器和状态 MLP,实现了对点云数据和状态数据的联合编码,适用于各种深度学习任务

第三部分 基于图像和点云的扩散策略:diffusion_policy_3d/policy(相当于包含2D和3D两个版本)

3.1 policy/base_policy.py:基类策略模型

该代码定义了一个名为 BasePolicy 的基类,用于实现策略模型。该类继承自 ModuleAttrMixin,并包含一些方法和接口,用于处理策略模型的基本功能

  1. 首先,导入了必要的库和模块,包括 Dict 类型提示、torch 和 torch.nn,以及自定义的 ModuleAttrMixin 和 LinearNormalizer 模块
  2. BasePolicy 类的构造函数接受一个关键字参数 `shape_meta`,该参数在配置文件中定义(例如 `config/task/*_image.yaml`)。然而,构造函数的具体实现并未在代码中展示。

    predict_action 方法是一个抽象方法,用于根据输入的观察字典 obs_dict 预测动作
    obs_dict 是一个字典,键为字符串,值为形状为 `(B, To, *)` 的张量。该方法的返回值是一个字典,键为字符串,值为形状为 `(B, Ta, Da)` 的张量。由于这是一个抽象方法,具体实现需要在子类中完成,因此在该方法中抛出了 NotImplementedError 异常
        def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:  # 定义 predict_action 方法,接受一个包含观察数据的字典,返回一个包含动作数据的字典
            """
            obs_dict:  # 观察数据字典
                str: B,To,*  # 键为字符串,值为形状为 (B, To, *) 的张量
            return: B,Ta,Da  # 返回形状为 (B, Ta, Da) 的张量
            """
            raise NotImplementedError()  # 抛出未实现的异常
    reset 方法用于重置状态,对于有状态的策略模型非常重要。该方法在基类中实现为空方法,具体实现可以在子类中覆盖

    set_normalizer 方法用于设置归一化器 normalizer,该归一化器是 LinearNormalizer 类型。由于没有标准的训练接口,该方法在基类中同样抛出了 NotImplementedError 异常,具体实现需要在子类中完成

总的来说,BasePolicy 类提供了一个策略模型的基本框架,定义了预测动作、重置状态和设置归一化器的方法接口。具体的策略模型需要继承该基类,并实现这些抽象方法

3.2 2D版本——policy/diffusion_image_policy.py:基于图像的扩散策略

DiffusionImagePolicy 类继承自 BasePolicy,用于实现基于扩散模型的图像策略

3.2.1 __init__

该类的构造函数接受多个参数,包括且不限于

    def __init__(self, 
            shape_meta: dict,   
            noise_scheduler: DDPMScheduler,      // 噪声调度器
            horizon,                             // 时间跨度
            n_action_steps,                      // 动作步数
            n_obs_steps,                         // 观察步数
            num_inference_steps=None,            // 推理步数
            obs_as_global_cond=True,             // 是否将观察作为全局条件
            crop_shape=(76, 76),                 // 裁剪形状
            diffusion_step_embed_dim=256,        // 扩散步嵌入维度
            down_dims=(256,512,1024),            // 下采样维度
            kernel_size=5,                       // 卷积核大小
            n_groups=8,                          // 组数
            condition_type='film',               // 条件类型
            use_depth=False,                     // 是否使用深度信息
            use_depth_only=False,                // 是否仅使用深度信息
            obs_encoder: TimmObsEncoder = None,  // 观察编码器
            # parameters passed to step
            **kwargs):

在初始化过程中,解析形状元数据,设置动作和观察的形状,并根据配置创建模型和相关组件

3.2.2 forward:根据输入的观察字典 obs_dict 生成动作

forward 方法用于根据输入的观察字典 obs_dict 生成动作

  1. 首先,复制输入的观察字典 obs_dict,以避免对原始数据进行修改
        def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:  # 定义前向传播函数
            obs_dict = obs_dict.copy()  # 复制观察字典
  2. 接着,对输入进行归一化处理,并将图像数据的像素值从 0-255 范围缩放到 0-1 范围
            # 归一化输入
            nobs = self.normalizer.normalize(obs_dict)  # 归一化观察字典
    
            nobs['image'] /= 255.0  # 将图像归一化到 [0, 1] 范围
    如果图像数据的最后一个维度为 3(表示 RGB 图像),则根据图像数据的维度进行维度转换
            if nobs['image'].shape[-1] == 3:      # 如果图像的最后一个维度为 3
                if len(nobs['image'].shape) == 5:      # 如果图像的形状长度为 5
                    nobs['image'] = nobs['image'].permute(0, 1, 4, 2, 3)   # 重新排列图像维度
                if len(nobs['image'].shape) == 4:      # 如果图像的形状长度为 4
                    nobs['image'] = nobs['image'].permute(0, 3, 1, 2)      # 重新排列图像维度
    如果使用深度信息且不只使用深度信息,则将深度信息与图像数据沿着通道维度拼接
            if self.use_depth and not self.use_depth_only:      # 如果使用深度信息但不只使用深度信息
                nobs['image'] = torch.cat([nobs['image'], nobs['depth'].unsqueeze(-3)], dim=-3)      # 将深度信息添加到图像中
    如果只使用深度信息,则将深度信息作为图像数据
            if self.use_depth and self.use_depth_only:      # 如果仅使用深度信息
                nobs['image'] = nobs['depth'].unsqueeze(-3)      # 将深度信息作为图像
  3. 接下来,从归一化后的观察字典中获取一个值,并提取其形状信息,包括批量大小 B 和观察步数 To
            value = next(iter(nobs.values()))  # 获取观察字典中的第一个值
            B, To = value.shape[:2]  # 获取批量大小和观察步数
    然后,设置时间跨度 T、动作维度 Da、观察特征维度 Do 和观察步数 To
            T = self.horizon          # 设置时间跨度
            Da = self.action_dim      # 设置动作维度
            Do = self.obs_feature_dim      # 设置观察特征维度
            To = self.n_obs_steps          # 设置观察步数
    构建输入数据时,获取设备信息 device 和数据类型 dtype
            # 构建输入
            device = self.device      # 获取设备
            dtype = self.dtype        # 获取数据类型
    处理不同的观察传递方式时,初始化局部条件 local_cond 和全局条件 global_cond
            # 处理不同的观察传递方式
            local_cond = None       # 局部条件
            global_cond = None      # 全局条件
    通过全局特征进行条件处理时,使用 dict_apply 函数对观察数据进行处理,并通过观察编码器 self.obs_encoder 提取观察特征
     
            # 通过全局特征进行条件处理
            # 获取前 n_obs_steps 步的观察数据
            this_nobs = dict_apply(nobs, lambda x: x[:,:self.n_obs_steps,...])  
    
            # 编码观察数据
            nobs_features = self.obs_encoder(this_nobs)
    将提取的观察特征重新调整形状为 `(B, Do)`,并赋值给 global_cond
            # 重新调整形状为 B, Do
            global_cond = nobs_features.reshape(B, -1)   # 重新调整观察特征的形状
    创建一个空的动作数据张量 cond_data 和一个全为 `False` 的掩码张量 cond_mask
            # 创建空的动作数据
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)  # 创建空的动作数据张量
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)              # 创建空的动作掩码张量
    然后,调用 conditional_sample 方法进行采样(下一节 会解释该方法),传入动作数据、掩码、局部条件和全局条件等参数
            # 运行采样
            # 调用 conditional_sample 方法进行采样
            nsample = self.conditional_sample(  
                cond_data, 
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                **self.kwargs)
    采样完成后,对预测的动作数据进行反归一化处理
            # 反归一化预测
            naction_pred = nsample[...,:Da]  # 获取采样结果中的动作预测
            action_pred = self.normalizer['action'].unnormalize(naction_pred)  # 反归一化动作预测
    最后,从预测的动作数据中提取所需的动作步数,并返回最终的动作
            # 获取动作
            start = To - 1      # 设置起始步数
            end = start + self.n_action_steps      # 设置结束步数
            action = action_pred[:,start:end]      # 获取动作预测结果
            
            # 获取预测结果
            return action       # 返回动作预测结果

通过这些步骤,forward 方法实现了从输入观察数据生成动作的过程,适用于基于扩散模型的图像策略

3.2.3 conditional_sample:给定条件下的采样

conditional_sample 方法用于在给定条件下生成样本

  1. 该方法接受多个参数,具体如下所示
    def conditional_sample(self,      # 定义 conditional_sample 方法
                condition_data, condition_mask,      # 接受条件数据和条件掩码
                local_cond=None, global_cond=None,   # 接受局部条件和全局条件,默认值为 None
                generator=None,       # 接受随机数生成器,默认值为 None
    
                # 此外,还可以传递其他关键字参数 kwargs 给调度器的 step 方法
                **kwargs
                ):
  2. 首先,方法获取模型 self.model 和噪声调度器 self.noise_scheduler
            model = self.model      # 获取模型
            scheduler = self.noise_scheduler  # 获取噪声调度器
  3. 然后,使用 torch.randn 函数生成一个与 condition_data 形状相同的随机轨迹张量 trajectory,并指定数据类型、设备和随机数生成器
            trajectory = torch.randn(      # 生成一个与条件数据形状相同的随机轨迹张量
                size=condition_data.shape,        # 形状与条件数据相同
                dtype=condition_data.dtype,       # 数据类型与条件数据相同
                device=condition_data.device,     # 设备与条件数据相同
                generator=generator)       # 使用指定的随机数生成器
  4. 接下来,设置调度器的时间步数 scheduler.set_timesteps(self.num_inference_steps)。在每个时间步 t 中
            # 设置时间步数
            scheduler.set_timesteps(self.num_inference_steps)
  5. 首先应用条件数据,将 condition_data 中满足条件掩码 condition_mask 的部分赋值给轨迹张量 trajectory
            # 遍历调度器的时间步数        
            for t in scheduler.timesteps:         
                # 1. 应用条件
                # 将条件数据中满足条件掩码的部分赋值给轨迹张量
                trajectory[condition_mask] = condition_data[condition_mask]  
    然后,使用模型预测输出 model_output,传入当前轨迹、时间步 t、局部条件 local_cond 和全局条件 global_cond
                # 2. 预测模型输出
                model_output = model(trajectory, t,      # 使用模型预测输出
                    local_cond=local_cond, global_cond=global_cond)  # 传入当前轨迹、时间步、局部条件和全局条件
    接着,调用调度器的 step 方法,计算前一个时间步的样本 `x_t-1`,并更新轨迹张量 trajectory
                # 3. 计算前一个时间步的样本:x_t -> x_t-1
                trajectory = scheduler.step(      # 调用调度器的 step 方法
                    model_output, t, trajectory,  # 传入模型输出、时间步和当前轨迹
                    generator=generator,          # 使用指定的随机数生成器
                    # **kwargs
                    ).prev_sample          # 获取前一个时间步的样本
  6. 最后,确保条件数据被强制应用,再次将 condition_data 中满足条件掩码 condition_mask 的部分赋值给轨迹张量 trajectory
            # 最后确保条件被强制应用
            # 再次将条件数据中满足条件掩码的部分赋值给轨迹张量
            trajectory[condition_mask] = condition_data[condition_mask]  
  7. 方法返回最终生成的轨迹张量 trajectory
            return trajectory      # 返回最终生成的轨迹张量

通过这些步骤,conditional_sample 方法实现了在给定条件下的样本生成过程,适用于基于扩散模型的图像策略

3.2.4 predict_action:根据输入的观察字典obs_dict预测动作

predict_action 方法用于根据输入的观察字典预测动作,该方法与 forward 方法类似

  1. 首先对输入进行归一化处理,并根据配置处理图像和深度信息
  2. 然后,构建输入数据,包括局部和全局条件
  3. 通过调用 conditional_sample 方法进行采样,得到未归一化的动作预测,并将其反归一化,返回最终的动作和动作预测结果

3.2.5 compute_loss:计算给定批次数据的损失

set_normalizer 方法用于设置归一化器 normalizer,通过加载归一化器的状态字典实现

compute_loss 方法用于计算给定批次数据的损失

  1. 首先,对输入进行归一化处理,并根据配置处理图像和深度信息
  2. 然后,构建输入数据,包括局部和全局条件。生成掩码,并添加噪声到轨迹中。应用条件数据,预测模型输出,并根据调度器的配置计算目标
  3. 最后,计算均方误差损失,并返回损失值

总的来说,DiffusionImagePolicy 类通过扩散模型和条件采样,实现了基于图像的策略生成和训练

3.3 3D版本——policy/diffusion_pointcloud_policy.py:基于点云的扩散策略(与3.2节有相似)

DiffusionPointcloudPolicy 类继承自 BasePolicy,用于实现基于扩散模型的点云策略

3.3.1 __init__

该类的构造函数接受多个参数,包括

    def __init__(self, 
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,    // 噪声调度器
            horizon,                           // 时间跨度
            n_action_steps,                    // 动作步数
            n_obs_steps,                       // 观察步数
            num_inference_steps=None,          // 推理步数
            obs_as_global_cond=True,           // 是否将观察作为全局条件
            diffusion_step_embed_dim=256,      // 扩散步嵌入维度
            down_dims=(256,512,1024),          // 下采样维度
            kernel_size=5,                     // 卷积核大小
            n_groups=8,                        // 组数
            condition_type="film",             // 条件类型
            use_down_condition=True,           // 是否使用下采样条件
            use_mid_condition=True,            // 是否使用中间条件
            use_up_condition=True,             // 是否使用上采样条件
            use_pc_color=False,                // 是否使用点云颜色
            pointnet_type="pointnet",          // 点网类型
            pointcloud_encoder_cfg=None,       // 点云编码器配置
            point_downsample=False,            // 是否对点云进行下采样

在初始化过程中,解析形状元数据,设置动作和观察的形状,并根据配置创建模型和相关组件。

3.3.2 forward:根据输入的观察字典 obs_dict 生成动作

forward 方法用于根据输入的观察字典 obs_dict 生成动作

  1. 首先,对输入进行归一化处理,并根据配置处理点云和颜色信息
  2. 然后,构建输入数据,包括局部和全局条件
  3. 通过调用 conditional_sample 方法进行采样,得到未归一化的动作预测,并将其反归一化,返回最终的动作

3.3.3 conditional_sample:在给定条件下进行采样

conditional_sample 方法用于在给定条件下进行采样

  1. 首先,生成一个随机的轨迹张量,并设置调度器的时间步数
  2. 在每个时间步中,应用条件数据,预测模型输出,并计算前一个时间步的样本
  3. 最后,确保条件数据被强制应用,返回最终的轨迹

3.3.4 predict_action:根据输入的观察字典 obs_dict 生成动作(与forward类似)

predict_action 方法用于根据输入的观察字典 obs_dict 生成动作,该方法与上面的forward类似

  1. 首先,对输入的观察字典进行归一化处理
    # 定义 predict_action 方法,接受一个包含观察数据的字典,返回一个包含动作数据的字典
    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 
            """
            obs_dict: 必须包含 "obs" 键
            result: 必须包含 "action" 键
            """
    
            # 归一化输入
            nobs = self.normalizer.normalize(obs_dict)      # 对观察数据进行归一化处理
    对于点云数据,如果不使用点云颜色,则只保留前三个通道(通常是坐标信息);
            if not self.use_pc_color:  # 如果不使用点云颜色
                nobs['point_cloud'] = nobs['point_cloud'][..., :3]    # 只保留前三个通道(通常是坐标信息)
    如果使用点云颜色,则将颜色信息归一化到 0-1 范围
            if self.use_pc_color:  # 如果使用点云颜色
                nobs['point_cloud'][..., 3:] /= 255.0      # 将颜色信息归一化到 0-1 范围
  2. 接下来,从归一化后的观察字典中获取一个值,并提取其形状信息,包括批量大小 B 和观察步数 To
            value = next(iter(nobs.values()))    # 获取归一化后的观察数据中的一个值
            B, To = value.shape[:2]              # 提取批量大小和观察步数
    然后,设置时间跨度 T、动作维度 Da、观察特征维度 Do 和观察步数 To
            T = self.horizon               # 设置时间跨度
            Da = self.action_dim           # 设置动作维度
            Do = self.obs_feature_dim      # 设置观察特征维度
            To = self.n_obs_steps          # 设置观察步数
    构建输入数据时,获取设备信息 device 和数据类型 dtype
            # 构建输入
            device = self.device      # 获取设备信息
            dtype = self.dtype        # 获取数据类型
    处理不同的观察传递方式时,初始化局部条件 local_cond 和全局条件 global_cond
            # 处理不同的观察传递方式
            local_cond = None       # 初始化局部条件
            global_cond = None      # 初始化全局条件
    如果将观察作为全局条件 obs_as_global_cond,则通过全局特征进行条件处理。使用 dict_apply 函数对观察数据进行处理,并通过观察编码器 self.obs_encoder 提取观察特征
            if self.obs_as_global_cond:          # 如果将观察作为全局条件
                # 通过全局特征进行条件处理
                this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))       # 对观察数据进行处理
                nobs_features = self.obs_encoder(this_nobs)      # 提取观察特征
    根据条件类型 condition_type,将提取的观察特征调整形状为 `(B, self.n_obs_steps, -1)` 或 `(B, -1)`,并赋值给 global_cond
                if "cross_attention" in self.condition_type:      # 如果条件类型为 "cross_attention"
                    # 作为序列处理
                    global_cond = nobs_features.reshape(B, self.n_obs_steps, -1)  # 将观察特征调整形状为 (B, self.n_obs_steps, -1)
    
                else:  
                    # 重新调整形状为 (B, Do)
                    global_cond = nobs_features.reshape(B, -1)      # 将观察特征调整形状为 (B, -1)
    创建一个空的动作数据张量 cond_data 和一个全为 `False` 的掩码张量 cond_mask
                # 空的动作数据
                # 创建一个空的动作数据张量
                cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)  
    
                 # 创建一个全为 False 的掩码张量
                cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)         
    如果不将观察作为全局条件,则通过填充的方式进行条件处理,即使用 dict_apply 函数对观察数据进行处理,并通过观察编码器提取观察特征
            else:  
                # 通过填充进行条件处理
                this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))  # 对观察数据进行处理
    
                nobs_features = self.obs_encoder(this_nobs)  # 提取观察特征
    将提取的观察特征调整形状为 `(B, To, -1)`,并将其赋值给 cond_data 的相应部分,同时更新 cond_mask
                # 重新调整形状为 (B, T, Do)
                # 将观察特征调整形状为 (B, To, -1)
                nobs_features = nobs_features.reshape(B, To, -1)  
    
                # 创建一个空的动作数据张量
                cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
      
                # 创建一个全为 False 的掩码张量
                cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) 
    
                # 将观察特征赋值给动作数据张量的相应部分 
                cond_data[:,:To,Da:] = nobs_features      
                # 更新掩码张量
                cond_mask[:,:To,Da:] = True               
  3. 接下来,调用 conditional_sample 方法进行采样,传入动作数据、掩码、局部条件和全局条件等参数
            # 运行采样
             # 调用 conditional_sample 方法进行采样
            nsample = self.conditional_sample( 
                cond_data,          # 动作数据
                cond_mask,          # 掩码
                local_cond=local_cond,        # 局部条件
                global_cond=global_cond,      # 全局条件
                **self.kwargs)      # 其他关键字参数
    采样完成后,对预测的动作数据进行反归一化处理
            # 反归一化预测
            naction_pred = nsample[...,:Da]          # 获取预测的动作数据
            action_pred = self.normalizer['action'].unnormalize(naction_pred)  # 对预测的动作数据进行反归一化处理
  4. 最后,从预测的动作数据中提取所需的动作步数
            # 获取动作
            start = To - 1          # 设置起始步数
            end = start + self.n_action_steps      # 设置结束步数
            action = action_pred[:,start:end]      # 从预测的动作数据中提取所需的动作步数
    并返回最终的动作和动作预测结果
            # 获取预测结果
            result = {
                'action': action,                # 动作
                'action_pred': action_pred,      # 动作预测
            }
            
            return result         # 返回最终的动作和动作预测结果

通过这些步骤,predict_action 方法实现了从输入观察数据生成动作的过程,适用于基于扩散模型的点云策略

3.3.5 compute_loss:计算给定批次数据的损失

set_normalizer 方法用于设置归一化器 normalizer,通过加载归一化器的状态字典实现

compute_loss 方法用于计算给定批次数据的损失

  1. 首先,对输入进行归一化处理,并根据配置处理点云和颜色信息
  2. 然后,构建输入数据,包括局部和全局条件。生成掩码,并添加噪声到轨迹中。应用条件数据,预测模型输出,并根据调度器的配置计算目标
  3. 最后,计算均方误差损失,并返回损失值和损失字典

总的来说,DiffusionPointcloudPolicy 类通过扩散模型和条件采样,实现了基于点云的策略生成和训练

至于iDP3的部署、训练、预处理请见此文iDP3的训练与部署代码解析:从数据可视化vis_dataset.py、训练脚本train.py到部署脚本deploy.py

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

v_JULY_v

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值