Shelving, Stacking, Hanging: Relational Pose Diffusion for Multi-modal Rearrangement(三)—— 训练代码解析
文章概括
引用:
@inproceedings{simeonov2023shelving,
title={Shelving, Stacking, Hanging: Relational Pose Diffusion for Multi-modal Rearrangement},
author={Simeonov, Anthony and Goyal, Ankit and Manuelli, Lucas and Lin, Yen-Chen and Sarmiento, Alina and Garcia, Alberto Rodriguez and Agrawal, Pulkit and Fox, Dieter},
booktitle={Conference on Robot Learning},
pages={2030--2069},
year={2023},
organization={PMLR}
}
Simeonov, A., Goyal, A., Manuelli, L., Lin, Y.C., Sarmiento, A., Garcia, A.R., Agrawal, P. and Fox, D., 2023, December. Shelving, Stacking, Hanging: Relational Pose Diffusion for Multi-modal Rearrangement. In Conference on Robot Learning (pp. 2030-2069). PMLR.
主页:https://anthonysimeonov.github.io/rpdiff-multi-modal/
原文: https://arxiv.org/abs/2307.04751
代码、数据和视频: https://github.com/anthonysimeonov/rpdiff
系列文章:
请在
《
《
《文章
》
》
》 专栏中查找
宇宙声明!
引用解析部分属于自我理解补充,如有错误可以评论讨论然后改正!
摘要
我们提出了一个系统,用于重新排列场景中的物体,以实现所需的物体–场景放置关系,例如将书插入书架的一个开放槽中。 该流程能够推广到新的几何形状、姿态以及场景和物体的布局,并通过示范进行训练,直接在3D点云上运行。我们的系统克服了给定场景存在许多在几何上相似的重排解决方案所带来的挑战。通过利用迭代姿态去噪训练过程,我们能够拟合多模态示范数据并产生多模态输出,同时保持精确和准确。我们还展示了在条件于相关的局部几何特征、同时忽略那些损害泛化性和精确性的不相关全局结构时的优势。我们在仿真和现实世界中,通过三个不同的重排任务验证了我们的方法,这些任务都需要处理多模态并在物体形状与姿态上实现泛化。项目网站、代码和视频: https://anthonysimeonov.github.io/rpdiff-multi-modal
关键词: 物体重排、多模态、操作、点云
train_full.py 代码解析(目录:rpdiff/src/rpdiff/training/train_full.py)
if name == “main”:
if __name__ == "__main__":
-
作用
判断当前模块是否作为“主程序”被直接运行。 -
为什么要这么写
避免他人在别的脚本里import train_full
时,不必要地执行后续初始化逻辑。 -
具体示例
- 直接在命令行执行
python train_full.py
时,__name__
为"__main__"
,条件成立,执行后续解析、训练逻辑。 - 若在
other.py
中写import train_full
,此处不执行。
- 直接在命令行执行
"""Parse input arguments"""
- 作用
这是一个独立的字符串字面量(docstring),因不在函数或模块顶层位置,对程序行为无影响,仅作注释。 - 为什么要这么写
标明下面几行是在“解析输入参数”,方便阅读。
parser = argparse.ArgumentParser()
- 作用
创建一个ArgumentParser
对象,用于定义和解析命令行参数。 - 为什么要这么写
Python 标准库中推荐的命令行参数解析方式,易于扩展和生成帮助信息。 - 示例
parser
此时为空,准备后续通过add_argument
添加各种参数。
parser.add_argument('-c', '--config_fname', type=str, required=True, help='Name of config file')
-
作用
定义一个必须参数-c
或--config_fname
,接受字符串,用于指定配置文件名。- 短写:
-c
文件路径;是 “短选项(short flag)” - 长写:
--config_fname
文件路径;是 “长选项(long flag)” - 两者指向同一个参数。当你在命令行中写
-c experiment1.yaml
或--config_fname experiment1.yaml
,程序就会把 “experiment1.yaml” 赋值给 args.config_fname。
- 短写:
-
为什么要这么写
训练脚本往往依赖一个配置文件(如 YAML/JSON)来加载超参等。设为required=True
强制用户提供。 -
示例
在命令行输入-c experiment1.yaml
,则解析后args.config_fname == "experiment1.yaml"
。
parser.add_argument('-d', '--debug', action='store_true', help='If True, run in debug mode')
-
作用
定义一个可选布尔开关-d/--debug
。若出现此标志,args.debug=True
,否则为False
。- 默认值:
args.debug
会被初始化为False
- 如果在命令行出现
-d/--debug
:则args.debug
会被设成True
- 默认值:
-
为什么要这么写
方便在调试时打开更详细的日志、可视化或数据校验。 -
示例
python train.py -c cfg.yaml -d
→args.debug == True
python train.py -c cfg.yaml
→args.debug == False
parser.add_argument('-dd', '--debug_data', action='store_true', help='If True, run data loader in debug mode')
- 作用
同上,但专门控制数据加载部分的“调试模式”(例如只加载少量样本)。 - 为什么要这么写
数据加载可能是瓶颈,调试时只想快速跑过几条数据。 - 示例
--debug_data
打开后,args.debug_data == True
,数据加载模块中可判定为“只取 16 条而非全部”。
parser.add_argument('-p', '--port_vis', type=int, default=6000, help='Port for ZMQ url (meshcat visualization)')
-
作用
定义一个整数型可选参数-p/--port_vis
(缺省 6000),指定 MeshCat 可视化服务器的端口。 -
为什么要这么写
MeshCat 默认端口是 6000,但用户可以通过命令行改成 7000、8000 等,避免端口冲突。 -
示例
- 默认:
args.port_vis == 6000
- 用户指定
-p 7000
→args.port_vis == 7000
- 默认:
parser.add_argument('-s', '--seed', type=int, default=0, help='Random seed')
- 作用
接收一个整数随机种子参数,缺省为 0。 - 为什么要这么写
实验可复现性需要固定随机数种子,且脚本运行时传不同种子可做多次实验对比。 - 示例
--seed 42
→args.seed == 42
,训练和数据加载中所有random.seed(42)
,np.random.seed(42)
。
parser.add_argument('-l', '--local_dataset_dir', type=str, default=None,
help='If the data is saved on a local drive, pass the root of that directory here')
- 作用
可选地指定本地数据集所在根目录,若不传则使用配置文件中的默认路径。 - 为什么要这么写
当数据下载在网络存储或集群不同节点时,用户可显式指定本地缓存目录。 - 示例
--local_dataset_dir /mnt/data/datasets
→args.local_dataset_dir == "/mnt/data/datasets"
parser.add_argument('-r', '--resume', action='store_true',
help='If set, resume experiment (required to be set in config as well)')
- 作用
布尔开关,若设置则脚本尝试从上一次中断的 checkpoint 恢复训练。 - 为什么要这么写
长周期训练易被打断,提供自动恢复能力。 - 示例
--resume
→args.resume == True
,后续代码中会读取train_args['resume']
来决定加载 checkpoint。
parser.add_argument('-m', '--meshcat', action='store_true',
help='If set, run with meshcat visualization (required to be set in config as well)')
- 作用
布尔开关,若打开则脚本内部启用 MeshCat 实时三维可视化。 - 为什么要这么写
可选择性地打开或关闭可视化,节省无视觉需求时的资源。 - 示例
--meshcat
→args.meshcat == True
args = parser.parse_args()
-
作用
解析命令行传入的所有参数,将结果存入args
(一个Namespace
对象)。 -
为什么要这么写
至此,所有用户输入都被标准化到args.xxxx
中,方便后续使用。 -
示例
若执行python train_full.py -c cfg.yaml -d -p 7000
,则args.config_fname == "cfg.yaml" args.debug == True args.port_vis == 7000 args.seed == 0 # 默认 args.resume == False # 默认
train_args = config_util.load_config(
osp.join(path_util.get_train_config_dir(), args.config_fname)
)
-
作用
path_util.get_train_config_dir()
返回项目中“训练配置文件”所在文件夹路径(如"/home/user/project/configs/train"
)。osp.join(...)
拼出完整文件路径,例如"/home/user/project/configs/train/exp1.yaml"
。config_util.load_config(path)
读取该 YAML/JSON 文件并解析成 Python 字典dict
。
-
为什么要这么写
解耦文件夹位置与文件名,集中管理路径;同时统一从配置文件加载超参。 -
具体示例
-
get_train_config_dir()
→"/project/configs/train"
-
args.config_fname = "demo.yaml"
-
最终
load_config("/project/configs/train/demo.yaml")
→{ "model": "resnet18", "lr": 0.001, "epochs": 100, # … 其他超参 … }
-
train_args['debug'] = args.debug
- 作用
将命令行解析出的布尔debug
标志写回配置字典中,覆盖配置文件中同名字段。 - 为什么要这么写
使得配置中的debug
与命令行保持一致——命令行优先级更高。 - 示例
若args.debug == True
,则train_args["debug"] = True
。
train_args['debug_data'] = args.debug_data
train_args['port_vis'] = args.port_vis
train_args['seed'] = args.seed
train_args['local_dataset_dir'] = args.local_dataset_dir
train_args['meshcat_ap'] = args.meshcat
train_args['resume_ap'] = args.resume
- 作用
同上,把所有命令行参数映射并覆盖到配置字典train_args
中。 - 为什么要这么写
统一在一个train_args
结构里管理所有运行时参数(无论是配置文件还是命令行传入)。 - 示例
如果用户未传--local_dataset_dir
,则args.local_dataset_dir is None
,train_args['local_dataset_dir']
也置为None
,此后代码会用配置文件或默认逻辑自动选择数据路径。
train_args = config_util.recursive_attr_dict(train_args)
-
作用
把普通字典(dict
)转换成“属性字典”/“AttrDict”,可以通过train_args.debug
、train_args.lr
访问,而不仅限于train_args['debug']
。 -
为什么要这么写
- 阅读性好:
args.seed
比args['seed']
简洁。 - 统一类型:函数参数
main
中可假定train_args
支持属性访问。
- 阅读性好:
-
示例
- 之前:
train_args['epochs'] == 100
- 之后:
train_args.epochs == 100
- 之前:
main(train_args)
-
作用
调用主流程函数main
,传入所有整理好的运行参数train_args
。 -
为什么要这么写
把脚本的核心逻辑封装在main
函数里,便于测试、复用和导入。 -
输入
-
train_args
:一个支持属性访问的参数对象,字段示例有train_args.config_fname # str,如 "demo.yaml" train_args.debug # bool train_args.lr # float,学习率 train_args.model # str,如 "resnet18" train_args.local_dataset_dir # Optional[str] # … 等等 …
-
核心大函数/类汇总
名称 | 作用 | 输入 | 输出/返回值 | 常见接口/方法 |
---|---|---|---|---|
path_util.get_train_config_dir() | 返回训练配置文件目录路径 | 无 | str ,如 "/path/to/configs/train" | 可扩展:set_train_config_dir(path: str) |
config_util.load_config(path) | 从 YAML/JSON 配置文件加载参数到 dict | path: str | Dict[str, Any] | save_config(dict, path) |
config_util.recursive_attr_dict(d) | 将普通 dict 递归地转为支持 .attr 访问的结构 | d: Dict[str, Any] | AttrDict | to_dict() , 支持嵌套 |
main(train_args) | 脚本主流程:数据加载、模型构建、训练、可视化、保存等 | train_args: AttrDict | None | Trainer , Evaluator , Visualizer 等类协作调用 |
def main(args: config_util.AttrDict):
-
用途:
main
是脚本的核心入口,负责:- 设置随机种子,保证可复现性;
- 搭建实验目录结构,记录本次运行的代码、配置、日志;
- 加载数据集,构建训练/验证数据管道;
- 初始化各子网络(粗糙可用性、位姿微调、成功分类),以及它们对应的损失函数和优化器;
- (可选)从 checkpoint 恢复训练状态;
- 最后调用
train
函数真正执行训练流程。
-
输入
-
args
:一个AttrDict
对象,属性来源于命令行参数和配置文件的合并,典型字段有:args.seed # 随机种子,例如 0 或 42 args.experiment # 实验相关子配置,包含 batch_size、num_iterations、logdir、resume_iter 等 args.data # 数据相关配置,包含 data_root、dataset_path、voxel_grid、rot_grid_bins 等 args.model # 三个子模型配置,包含 refine_pose、success、coarse_affordance 等 args.loss # 对应三个子模型的损失设置 args.optimizer # 对应三个子模型的优化器设置 args.meshcat_ap # 是否打开 meshcat 可视化(命令行 `--meshcat`) args.port_vis # meshcat 端口,例如 6000 args.local_dataset_dir # 本地数据根目录,例如 "/mnt/data/datasets" args.debug_data # 数据加载调试开关
-
-
输出
- 无返回值,但会在
logdir
下生成模型、日志、可视化结果,并在屏幕/TensorBoard 打印训练信息。
- 无返回值,但会在
1. 设置随机种子,保证可复现性
random.seed(args.seed)
- 作用:为 Python 标准库的
random
模块设置种子。 - 示例:若
args.seed = 42
,则random.seed(42)
;后续调用random.random()
会固定输出序列。 - 为什么要这么写:确保使用
random
生成的随机数(如数据打乱顺序)在每次运行中一致。
torch.manual_seed(args.seed)
- 作用:为 PyTorch CPU 和单卡 GPU 设置全局随机种子。
- 示例:同样用
42
,后续所有torch.randn()
、torch.randperm()
等操作可复现。 - 为什么要这么写:深度学习训练中的权重初始化、数据增强等都依赖随机,需要固定。
np.random.seed(args.seed)
- 作用:为 NumPy 随机数生成器设置种子。
- 示例:
np.random.seed(42)
后,np.random.rand(3)
每次均输出相同数组。 - 为什么要这么写:如果数据预处理或其它步骤中使用了 NumPy 随机,需保持一致性。
2. 搭建实验目录与日志记录
##############################################
# Setup basic experiment params
- 注释:标志“搭建基础实验参数”这一逻辑块的开始。
logdir = osp.join(
path_util.get_rpdiff_model_weights(),
args.experiment.logdir,
args.experiment.experiment_name)
-
作用:拼出主存储目录
logdir
,用于保存模型权重、TensorBoard 日志等。 -
示例:
path_util.get_rpdiff_model_weights()
→"/home/user/rpdiff_weights"
args.experiment.logdir = "logs"
args.experiment.experiment_name = "exp1"
- 最终
logdir = "/home/user/rpdiff_weights/logs/exp1"
。
-
为什么要这么写:统一管理模型和日志,便于查找和对比不同实验。
util.safe_makedirs(logdir)
- 作用:若
logdir
不存在,递归创建目录;若已存在则忽略错误。 - 示例:
"/home/user/rpdiff_weights/logs/exp1"
不存在时创建。 - 为什么要这么写:避免后续写文件时报错,同时保留已有内容。
# Set up experiment run/config logging
nowstr = datetime.datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
-
作用:获取当前时间字符串,用于区分多次运行。
-
示例:假设当前时间 2025-05-07 14:23:45,则
nowstr == "05-07-2025_14-23-45"
-
为什么要这么写:在同一
exp1
文件夹下,多次启动时用不同子文件夹保存。
run_logs = osp.join(logdir, 'run_logs')
- 作用:在主目录下拼出“运行日志”子目录路径。
- 示例:
"/home/user/.../exp1/run_logs"
。 - 为什么要这么写:将所有运行记录集中在
run_logs
里,清晰分层。
util.safe_makedirs(run_logs)
- 同上:递归创建
run_logs
。
run_log_folder = osp.join(run_logs, nowstr)
- 作用:拼出当前这次运行的唯一日志文件夹名。
- 示例:
"/home/.../exp1/run_logs/05-07-2025_14-23-45"
。 - 为什么要这么写:分次管理,避免文件冲突。
util.safe_makedirs(run_log_folder)
- 同上:创建本次运行日志目录。
# copy everything we would like to know about this run in the run log folder
for fn in os.listdir(os.getcwd()):
if not (fn.endswith('.py') or fn.endswith('.sh') or fn.endswith('.bash')):
continue
log_fn = osp.join(run_log_folder, fn)
shutil.copy(fn, log_fn)
- 作用:把当前工作目录下所有
.py
、.sh
、.bash
脚本文件复制到运行日志目录,用于版本跟踪。 - 示例:若项目根有
train.py
、utils.py
、run.sh
,则都拷贝过去。 - 为什么要这么写:确保当时的代码快照被保留,方便日后复现或排查。
full_cfg_dict = copy.deepcopy(config_util.recursive_dict(args))
-
作用:将属性字典
args
转回普通dict
,并做深拷贝。 -
示例:
full_cfg_dict = { "seed": 42, "experiment": {"logdir": "logs", "experiment_name": "exp1", ...}, "model": {...}, ... }
-
为什么要这么写:把合并后、最终运行时的所有配置保存为纯文本。
full_cfg_fname = osp.join(run_log_folder, 'full_exp_cfg.txt')
- 作用:拼出配置保存文件名。
- 示例:
".../05-07-2025_14-23-45/full_exp_cfg.txt"
。
json.dump(full_cfg_dict, open(full_cfg_fname, 'w', encoding='utf-8'), ensure_ascii=False, indent=4)
-
作用:以 UTF-8、漂亮缩进格式将配置字典写入磁盘。
-
示例:文件内容类似:
{ "seed": 42, "experiment": { "logdir": "logs", "experiment_name": "exp1", ... }, ... }
-
为什么要这么写:直观查看与脚本外部参数结合后的完整配置,增强透明度。
3. 可视化(MeshCat)初始化
if args.experiment.meshcat_on and args.meshcat_ap:
-
作用:检查配置与命令行是否同时打开了 meshcat 可视化。
-
示例:
- 在
args.experiment.yaml
中有meshcat_on: true
,且命令行传了--meshcat
,条件成立。
- 在
-
为什么要这么写:双重开关确保用户有意图使用可视化,避免无意占用端口。
zmq_url=f'tcp://127.0.0.1:{args.port_vis}'
- 作用:按格式拼出 ZMQ 通信 URL。
- 示例:
"tcp://127.0.0.1:6000"
。
mc_vis = meshcat.Visualizer(zmq_url=zmq_url)
-
作用:创建一个 MeshCat 可视化客户端。
-
输入:
zmq_url
:ZMQ 监听地址;
-
输出:一个
meshcat.Visualizer
对象,可通过mc_vis['scene']
操作 3D 场景。 -
为什么要这么写:将过程中的点云、网格、坐标系等实时推送到浏览器查看。
-
常见接口:
mc_vis['scene'][<name>].set_object(...)
:添加模型;mc_vis['scene'].delete()
:清空场景。
mc_vis['scene'].delete()
- 作用:删除场景中已有对象,保证每次启动时画布是空白的。
- 为什么要这么写:避免重叠显示上次运行残留。
else:
mc_vis = None
- 作用:若关闭可视化,则将
mc_vis
设为None
,下游代码判断后跳过可视化调用。 - 为什么要这么写:统一接口,避免每次都判断两个标志。
# prepare dictionary for extra kwargs in train function
train_kwargs = {}
- 作用:预先创建一个空字典,后续若有“第二模型”或自定义参数则填入,最后传给
train(...)
。 - 为什么要这么写:保持
train
函数接口统一,方便可选输入。
4. 准备数据集与 DataLoader
##############################################
# Prepare dataset and dataloader
- 注释:开始数据加载模块。
data_args = args.data
- 作用:局部变量引用,减少后续多次
args.data
的冗余书写。 - 示例:
data_args.data_root = "shape_data"
,data_args.dataset_path = "task1"
。
if osp.exists(str(args.local_dataset_dir)):
- 作用:判断用户是否通过命令行指定了一个真实存在的本地数据目录。
- 示例:用户传
--local_dataset_dir /mnt/data/datasets
,若该目录存在,则为真。
dataset_path = osp.join(
args.local_dataset_dir,
data_args.data_root,
data_args.dataset_path)
- 作用:若本地目录存在,则从本地拼出二级路径。
- 示例:
"/mnt/data/datasets/shape_data/task1"
。
else:
dataset_path = osp.join(
path_util.get_rpdiff_data(),
data_args.data_root,
data_args.dataset_path)
- 作用:否则从默认网络/集群路径加载数据。
- 示例:
"/home/user/rpdiff_data/shape_data/task1"
。
assert osp.exists(dataset_path), f'Dataset path: {dataset_path} does not exist'
- 作用:若最终路径不存在,抛出错误并打印具体路径,帮助定位问题。
- 示例:若路径错了,就会报
AssertionError: Dataset path: … does not exist
。
train_dataset = dataio.FullRelationPointcloudPolicyDataset(
dataset_path,
data_args,
phase='train',
train_coarse_aff=args.experiment.train.train_coarse_aff,
train_refine_pose=args.experiment.train.train_refine_pose,
train_success=args.experiment.train.train_success,
mc_vis=mc_vis,
debug_viz=args.debug_data)
-
作用:实例化训练集对象。
-
输入:
dataset_path: str
,数据根目录;data_args: AttrDict
,包含voxel_grid
、rot_grid
等预处理参数;phase='train'
:区分 train/val,不同阶段做数据增强;- 三个布尔开关:是否返回粗可用性、位姿微调、成功标记标签;
mc_vis: meshcat.Visualizer or None
:可视化句柄;debug_viz: bool
:调试模式下每条数据只可视化少量样本。
-
输出:一个 PyTorch
Dataset
实例,支持__len__()
,__getitem__(idx)
接口。 -
为什么要这么写:封装数据逻辑,统一返回三类任务所需输入。
val_dataset = dataio.FullRelationPointcloudPolicyDataset(
dataset_path,
data_args,
phase='val',
train_coarse_aff=args.experiment.train.train_coarse_aff,
train_refine_pose=args.experiment.train.train_refine_pose,
train_success=args.experiment.train.train_success,
mc_vis=mc_vis,
debug_viz=args.debug_data)
- 同上,但
phase='val'
,通常不做随机增强、shuffle。
train_dataloader = DataLoader(
train_dataset,
batch_size=args.experiment.batch_size,
shuffle=True,
num_workers=args.experiment.num_train_workers,
drop_last=True)
-
作用:用 PyTorch 自带
DataLoader
包装训练集:batch_size
: 如32
shuffle=True
: 每 epoch 打乱顺序num_workers
: 并行加载线程数,如4
drop_last=True
: 丢弃最后一个不满 batch
-
输出:可迭代的
train_dataloader
,每次返回一批数据字典。
val_dataloader = DataLoader(
val_dataset,
batch_size=2,
num_workers=1,
shuffle=False,
drop_last=True)
- 作用:验证集一般 batch 较小(这里写死
2
),不开 shuffle,线程数少(1
),drop_last 保持整批。 - 示例:用于定期在训练过程中评估性能。
# grab some things we need for training
args.experiment.epochs = args.experiment.num_iterations / len(train_dataloader)
-
作用:自动计算训练周期(epoch)数:
epochs = 总迭代次数 每个 epoch 的 batch 数 \text{epochs} = \frac{\text{总迭代次数}}{\text{每个 epoch 的 batch 数}} epochs=每个 epoch 的 batch 数总迭代次数
-
示例:若
num_iterations = 10000
且len(train_dataloader)=312
,则a r g s . e x p e r i m e n t . e p o c h s ≈ 10000 / 312 ≈ 32.05 args.experiment.epochs \approx 10000/312 \approx 32.05 args.experiment.epochs≈10000/312≈32.05
-
为什么要这么写:方便后续按 epoch 进行学习率调度或日志记录。
args.data.rot_grid_bins = train_dataset.rot_grid.shape[0]
- 作用:将数据集实例内部生成的旋转网格维度输出回配置,用于后续模型初始化。
- 示例:若
rot_grid
是(24, 3)
数组,则rot_grid_bins = 24
。 - 为什么要这么写:保持数据和模型对齐,避免手动同步出错。
5. 初始化三个子网络、损失与优化器
5.1 位姿微调模型(Pose Refinement)
pr_type = args.model.refine_pose.type
- 作用:读取配置中要用的微调模型类型字符串,如
"nsm_transformer"
。 - 示例:
pr_type = "nsm_transformer"
。
pr_args = config_util.copy_attr_dict(args.model[pr_type])
- 作用:复制该模型类型对应的参数子字典,如
args.model['nsm_transformer']
,用于初始化。 - 示例:
pr_args = {"num_layers":4, "hidden_dim":128, ...}
。
if args.model.refine_pose.get('model_kwargs') is not None:
custom_pr_args = args.model.refine_pose.model_kwargs[pr_type]
config_util.update_recursive(pr_args, custom_pr_args)
-
作用:若在
args.model.refine_pose.model_kwargs
中有针对不同pr_type
的自定义参数,则递归地合并覆盖。 -
示例:
model: refine_pose: type: nsm_transformer model_kwargs: nsm_transformer: hidden_dim: 256
上面就把
hidden_dim
从默认128
更新为256
。 -
为什么要这么写:支持一次配置文件中同时保存“通用参数”与“某模型专属参数”。
if pr_type == 'nsm_transformer':
pr_model_cls = NSMTransformerSingleTransformationRegression
elif pr_type == 'nsm_transformer_cvae':
pr_model_cls = NSMTransformerSingleTransformationRegressionCVAE
else:
raise ValueError(f'Unrecognized: {pr_type}')
-
作用:根据字符串选择具体的模型类。
-
示例:若
"nsm_transformer"
,则pr_model_cls == NSMTransformerSingleTransformationRegression
-
为什么要这么写:避免硬编码,仅靠 config 决定模型类型。
-
常见接口(类构造函数):
pr_model_cls(mc_vis, feat_dim, num_layers, hidden_dim, ...)
返回一个
torch.nn.Module
子类实例。
pr_model = pr_model_cls(
mc_vis=mc_vis,
feat_dim=args.model.refine_pose.feat_dim,
**pr_args).cuda()
-
作用:实例化微调模型并
.cuda()
转到 GPU。 -
示例:
pr_model = NSMTransformerSingleTransformationRegression( mc_vis=None, feat_dim=64, num_layers=4, hidden_dim=128 ).cuda()
-
为什么要这么写:GPU 加速计算。
-
输出:
pr_model
,含forward(...)
接口,接收点云 + 初始变换,输出细化后的位姿预测。
pr_model_params = pr_model.parameters()
- 作用:收集模型所有可训练参数,用于优化器。
- 示例:生成一个 Python
generator
,遍历后是若干torch.Tensor
。
# loss
pr_loss_type = args.loss.refine_pose.type
- 作用:读取微调模型对应的损失函数类型字符串,如
"tf_chamfer"
。
assert pr_loss_type in args.loss.refine_pose.valid_losses, \
f'Loss type: {pr_loss_type} not in {args.loss.refine_pose.valid_losses}'
- 作用:校验配置合法性,如果不在支持列表中则报错。
- 示例:若
valid_losses = ['tf_chamfer', 'tf_chamfer_w_kldiv']
,并且pr_loss_type='mse'
,此处就会断言失败。
if pr_loss_type == 'tf_chamfer':
tfc_mqa_wrapper = losses.TransformChamferWrapper(
l1=args.loss.tf_chamfer.l1,
trans_offset=args.loss.tf_chamfer.trans_offset)
pr_loss_fn = tfc_mqa_wrapper.tf_chamfer
-
作用:
-
实例化一个
TransformChamferWrapper
,包装 Chamfer 距离损失。l1
: 是否加 L1 位置误差分量,布尔或系数,例如True
或0.1
trans_offset
: 平移偏移惩罚系数,如0.01
-
从包装器中取出纯函数
tf_chamfer
作为损失函数。
-
-
输出:
pr_loss_fn
,接口签名类似loss = pr_loss_fn(pred_transforms, target_transforms)
-
为什么要这么写:实现多种损失切换,并复用包装器中预处理逻辑。
elif pr_loss_type == 'tf_chamfer_w_kldiv':
tfc_mqa_wrapper = losses.TransformChamferWrapper(
l1=args.loss.tf_chamfer.l1,
trans_offset=args.loss.tf_chamfer.trans_offset,
kl_div=True)
pr_loss_fn = tfc_mqa_wrapper.tf_chamfer_w_kldiv
- 作用:同上,但在 wrapper 中打开
kl_div
标志,生成带 KLDivergence 的损失函数。
else:
raise ValueError(f'Unrecognized: {pr_loss_fn}')
- 防御性编程:若意外遇到未知类型,报错提醒配置错误。
# optimizer
pr_opt_type = args.optimizer.refine_pose.type
assert pr_opt_type in args.optimizer.refine_pose.valid_opts, \
f'Opt type: {pr_opt_type} not in {args.optimizer.refine_pose.valid_opt}'
- 作用:类似损失,读取并校验优化器类型(如
"Adam"
、"AdamW"
)。
if pr_opt_type == 'Adam':
pr_opt_cls = torch.optim.Adam
elif pr_opt_type == 'AdamW':
pr_opt_cls = torch.optim.AdamW
else:
raise ValueError(f'Unrecognized: {pr_opt_type}')
- 作用:根据字符串映射到具体 PyTorch 优化器类。
pr_opt_kwargs = config_util.copy_attr_dict(args.optimizer[pr_opt_type])
- 作用:读取通用优化器超参,如
{"lr":0.001, "weight_decay":1e-4}
。
if args.optimizer.refine_pose.get('opt_kwargs') is not None:
custom_pr_opt_kwargs = args.optimizer.refine_pose.opt_kwargs[pr_opt_type]
config_util.update_recursive(pr_opt_kwargs, custom_pr_opt_kwargs)
- 作用:合并 refine_pose 专属的优化器参数覆盖通用设置。
pr_optimizer = pr_opt_cls(pr_model_params, **pr_opt_kwargs)
-
作用:生成微调模型的优化器实例。
-
示例:
pr_optimizer = torch.optim.Adam( pr_model.parameters(), lr=0.001, weight_decay=1e-4 )
-
为什么要这么写:后续在训练循环中调用
pr_optimizer.zero_grad()
、pr_optimizer.step()
。
5.2 成功分类器(Success Classifier)
整个流程与位姿模型几乎一致,这里只列要点。
sc_type = args.model.success.type
sc_args = config_util.copy_attr_dict(args.model[sc_type])
if args.model.success.get('model_kwargs') is not None:
custom_sc_args = args.model.success.model_kwargs[sc_type]
config_util.update_recursive(sc_args, custom_sc_args)
- 作用:读取、复制并可能合并成功分类模型的参数。
# model
if sc_type == 'nsm_transformer':
success_model_cls = NSMTransformerSingleSuccessClassifier
else:
raise ValueError(f'Unrecognized success model type: {sc_type}')
- 接口:
NSMTransformerSingleSuccessClassifier(mc_vis, feat_dim, **sc_args)
返回torch.nn.Module
,forward
输入点云与候选位姿,输出成功概率。
success_model = success_model_cls(
mc_vis=mc_vis,
feat_dim=args.model.success.feat_dim,
**sc_args).cuda()
sc_model_params = success_model.parameters()
- 示例:
feat_dim=64
,其他参数如num_heads=8
。
# loss
sc_loss_type = args.loss.success.type
assert sc_loss_type in args.loss.success.valid_losses, f'Loss type: {sc_loss_type} not in {args.loss.success.valid_losses}'
if sc_loss_type == 'bce_wo_logits':
sc_loss_fn = losses.success_bce
- 作用:若不带 logits 的二分类交叉熵,直接用封装的函数
success_bce(pred, label)
。
elif sc_loss_type == 'bce_w_logits':
double_batch = args.loss.bce_w_logits.double_batch_size
batch_scalar = 2 if double_batch else 1 # in some experiments, we double the batch size for the success model
bce_logits_wrapper = losses.BCEWithLogitsWrapper(
pos_weight=args.loss.bce_w_logits.pos_weight,
bs=args.experiment.batch_size*batch_scalar)
sc_loss_fn = bce_logits_wrapper.success_bce_w_logits
else:
raise ValueError(f'Unrecognized: {sc_loss_type}')
- 说明:带 logits 的版本,可指定
pos_weight
、是否加倍 batch,对正负样本不均衡情况友好。
# optimizer
sc_opt_type = args.optimizer.success.type
assert sc_opt_type in args.optimizer.success.valid_opts, ...
if sc_opt_type == 'Adam':
sc_opt_cls = torch.optim.Adam
elif sc_opt_type == 'AdamW':
sc_opt_cls = torch.optim.AdamW
else:
raise ValueError(f'Unrecognized: {sc_opt_type}')
sc_opt_kwargs = config_util.copy_attr_dict(args.optimizer[sc_opt_type])
if args.optimizer.success.get('opt_kwargs') is not None:
custom_sc_opt_kwargs = args.optimizer.success.opt_kwargs[sc_opt_type]
config_util.update_recursive(sc_opt_kwargs, custom_sc_opt_kwargs)
sc_optimizer = sc_opt_cls(sc_model_params, **sc_opt_kwargs)
- 作用:同位姿模型,初始化成功分类器的优化器。
5.3 粗糙可用性网络(Coarse Affordance)
# model
coarse_aff_type = args.model.coarse_aff.type
coarse_aff_args = config_util.copy_attr_dict(args.model[coarse_aff_type])
if args.model.coarse_aff.get('model_kwargs') is not None:
custom_coarse_aff_args = args.model.coarse_aff.model_kwargs[coarse_aff_type]
config_util.update_recursive(coarse_aff_args, custom_coarse_aff_args)
- 作用:加载 coarse affordance 网络配置。
if args.model.coarse_aff.multi_model:
coarse_aff_args2 = config_util.copy_attr_dict(args.model[coarse_aff_type])
if args.model.coarse_aff.get('model_kwargs2') is not None:
custom_coarse_aff_args2 = args.model.coarse_aff.model_kwargs2[coarse_aff_type]
config_util.update_recursive(coarse_aff_args2, custom_coarse_aff_args2)
- 说明:若要启用“双模型”模式,则对第二个模型做同样配置,后面会并行训练。
coarse_aff_model = CoarseAffordanceVoxelRot(
mc_vis=mc_vis,
feat_dim=args.model.coarse_aff.feat_dim,
rot_grid_dim=args.data.rot_grid_bins,
padding=args.data.voxel_grid.padding,
voxel_reso_grid=args.data.voxel_grid.reso_grid,
euler_rot=args.model.coarse_aff.euler_rot,
euler_bins_per_axis=args.model.coarse_aff.euler_bins_per_axis,
scene_encoder_kwargs=coarse_aff_args).cuda()
-
作用:初始化粗糙可用性网络并上 GPU。
-
输入:
feat_dim
: 特征维度,例如 64rot_grid_dim
: 旋转采样点数,如 24padding
、voxel_reso_grid
: 体素网格大小/分辨率,例如[32,32,32]
euler_rot
、euler_bins_per_axis
: 是否启用欧拉旋转采样及每轴 bins 数scene_encoder_kwargs
: 上面合并后的网络超参字典
-
输出:
torch.nn.Module
,forward(pointcloud)
返回每个体素的 affordance 概率。
aff_model_params = coarse_aff_model.parameters()
- 作用:收集参数。
# loss
aff_loss_type = args.loss.coarse_aff.type
assert aff_loss_type in args.loss.coarse_aff.valid_losses, f'Loss type: {aff_loss_type} not in {args.loss.coarse_aff.valid_losses}'
if aff_loss_type == 'voxel_affordance':
aff_loss_fn = losses.voxel_affordance
elif aff_loss_type == 'voxel_affordance_w_disc_rot':
aff_loss_fn = losses.voxel_affordance_w_disc_rot
elif aff_loss_type == 'voxel_affordance_w_disc_rot_euler':
aff_loss_fn = losses.voxel_affordance_w_disc_rot_euler
else:
raise ValueError(f'Unrecognized: {aff_loss_type}')
- 作用:根据类型切换三种可用性损失函数,分别是基础版本、带离散旋转惩罚、带欧拉旋转惩罚。
# optimizer
aff_opt_type = args.optimizer.coarse_aff.type
assert aff_opt_type in args.optimizer.coarse_aff.valid_opts, f'Opt type: {aff_opt_type} not in {args.optimizer.coarse_aff.valid_opt}'
if aff_opt_type == 'AdamW':
aff_opt_cls = torch.optim.AdamW
elif aff_opt_type == 'Adam':
aff_opt_cls = torch.optim.Adam
else:
raise ValueError(f'Unrecognized: {aff_opt_type}')
aff_opt_kwargs = config_util.copy_attr_dict(args.optimizer[aff_opt_type])
if args.optimizer.coarse_aff.get('opt_kwargs') is not None:
custom_aff_opt_kwargs = args.optimizer.coarse_aff.opt_kwargs[aff_opt_type]
config_util.update_recursive(aff_opt_kwargs, custom_aff_opt_kwargs)
aff_optimizer = aff_opt_cls(aff_model_params, **aff_opt_kwargs)
- 作用:初始化 coarse affordance 的优化器。
if args.model.coarse_aff.multi_model:
coarse_aff_args2 = config_util.copy_attr_dict(args.model[coarse_aff_type])
if args.model.coarse_aff.get('model_kwargs2') is not None:
custom_coarse_aff_args2 = args.model.coarse_aff.model_kwargs2[coarse_aff_type]
config_util.update_recursive(coarse_aff_args2, custom_coarse_aff_args2)
voxel_reso_grid2 = args.data.voxel_grid.reso_grid
voxel_reso_grid2 = util.set_if_not_none(voxel_reso_grid2, args.model.coarse_aff.model2.voxel_grid.reso_grid)
padding2 = args.data.voxel_grid.padding
padding2 = util.set_if_not_none(padding2, args.data.voxel_grid.padding)
coarse_aff_model2 = CoarseAffordanceVoxelRot(
mc_vis=mc_vis,
feat_dim=args.model.coarse_aff.feat_dim,
rot_grid_dim=args.data.rot_grid_bins,
padding=padding2,
voxel_reso_grid=voxel_reso_grid2,
scene_encoder_kwargs=coarse_aff_args2).cuda()
aff_model_params2 = coarse_aff_model2.parameters()
aff_loss_fn2 = aff_loss_fn
aff_optimizer2 = aff_opt_cls(aff_model_params2, **aff_opt_kwargs)
train_kwargs['coarse_aff_model2'] = {}
train_kwargs['coarse_aff_model2']['model'] = coarse_aff_model2
train_kwargs['coarse_aff_model2']['opt'] = aff_optimizer2
train_kwargs['coarse_aff_model2']['loss_fn'] = aff_loss_fn2
train_kwargs['coarse_aff_model2']['reso_grid'] = voxel_reso_grid2
train_kwargs['coarse_aff_model2']['padding'] = padding2
# model_dict = dict(model=model, rot=rot_model)
# if args.debug:
# print('Coarse affordance model: ')
# print(coarse_aff_model)
# print('Refine pose model: ')
# print(pr_model)
# print('Success model: ')
# print(success_model)
- 作用:若双模型模式,将第二套模型、优化器、损失、分辨率等都放入
train_kwargs
,供后续train
函数使用。
6. 从 Checkpoint 恢复(可选)
##############################################
# Load checkpoints if resuming
- 注释:开始恢复逻辑。
if args.experiment.resume and args.resume_ap:
- 作用:仅在配置文件和命令行都要求
resume=True
时才执行。
# find the latest iteration
ckpts = [int(val.split('model_')[1].replace('.pth', ''))
for val in os.listdir(logdir)
if (val.endswith('.pth') and 'latest' not in val)]
- 作用:扫描
logdir
下所有命名为model_XXX.pth
文件,取出数值部分作为可恢复的迭代号列表。 - 示例:文件夹里有
model_1000.pth
,model_2000.pth
, 则ckpts = [1000, 2000]
。
args.experiment.resume_iter = max(ckpts)
- 作用:选择最大的迭代号作为恢复点。
- 示例:
resume_iter = 2000
。
if args.experiment.resume_iter != 0:
- 作用:若恢复迭代号不为 0,则真正加载对应 checkpoint。默认
resume_iter=0
表示从头开始。
print(f'Resuming at iteration: {args.experiment.resume_iter}')
- 作用:在屏幕打印恢复点,增强可见性。
model_path = osp.join(logdir, f'model_{args.experiment.resume_iter}.pth')
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
-
作用:
- 拼出 checkpoint 路径,如
.../model_2000.pth
; - 用 CPU 加载(先到 CPU 再
.cuda()
或直接在 GPU 上加载也可)。
- 拼出 checkpoint 路径,如
-
输出:
checkpoint
是一个 dict,典型键包括:{ 'refine_pose_model_state_dict': {...}, 'pr_optimizer_state_dict': {...}, 'success_model_state_dict': {...}, 'sc_optimizer_state_dict': {...}, 'coarse_aff_model_state_dict': {...}, 'aff_optimizer_state_dict': {...}, # 如双模型还会有 _state_dict2 }
if args.experiment.train.train_refine_pose:
pr_model.load_state_dict(checkpoint['refine_pose_model_state_dict'])
pr_optimizer.load_state_dict(checkpoint['pr_optimizer_state_dict'])
- 作用:若该任务开启,则分别恢复模型权重和优化器状态(包括动量、学习率调度信息等)。
if args.experiment.train.train_success:
success_model.load_state_dict(checkpoint['success_model_state_dict'])
sc_optimizer.load_state_dict(checkpoint['sc_optimizer_state_dict'])
- 同上:成功分类器恢复。
if args.experiment.train.train_coarse_aff:
coarse_aff_model.load_state_dict(checkpoint['coarse_aff_model_state_dict'])
aff_optimizer.load_state_dict(checkpoint['aff_optimizer_state_dict'])
if args.model.coarse_aff.multi_model:
coarse_aff_model2.load_state_dict(checkpoint['coarse_aff_model_state_dict2'])
aff_optimizer.load_state_dict(checkpoint['aff_optimizer_state_dict2'])
- 同上:粗糙可用性网络(以及第二模型)恢复。
7. TensorBoard 日志与设备选择
logger = SummaryWriter(logdir)
-
作用:创建 TensorBoard 日志写入器,所有
logger.add_scalar(...)
会保存在该目录下。 -
输入:
logdir
,例如"/home/.../exp1"
-
输出:
logger
对象,常见接口:logger.add_scalar('loss/train', loss_value, global_step) logger.add_image('viz/pointcloud', img_tensor, global_step)
it = args.experiment.resume_iter
- 作用:把恢复的迭代号赋给
it
,作为训练循环的起始步数。 - 示例:若恢复到 2000,则从
it=2000
开始继续计数。
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
raise ValueError('Cuda not available')
-
作用:
- 检查是否有 GPU;
- 若有,则将
device
设为第一张卡cuda:0
; - 否则报错。
-
为什么要这么写:脚本依赖 GPU 训练,若无 GPU 需立刻停止。
8. 配置合法性检查
##############################################
# Perform other checks to ensure experiment config is valid
- 注释:开始一些互斥/依赖参数的断言检查。
train_exp_args = config_util.copy_attr_dict(args.experiment.train)
- 作用:复制一份训练子配置到局部变量中,简化书写。
if train_exp_args.refine_pose_from_coarse_pred:
assert train_exp_args.train_coarse_aff and train_exp_args.train_refine_pose, \
'Must be training both coarse and refine to use predictions as refinement input'
- 作用:若想用粗糙网络的输出作为微调网络输入,则必须同时训练两者。
assert not (train_exp_args.success_from_coarse_pred and train_exp_args.success_from_refine_pred), \
'Cannot predict success from both refine pred and coarse pred. Please only set one of these to True'
- 作用:不能同时使用两种预测结果来判断成功,否则逻辑冲突。
if train_exp_args.success_from_coarse_pred:
assert train_exp_args.train_coarse_aff and train_exp_args.train_success, \
'Must be training both coarse and success to use predictions as success input'
- 作用:依赖检查。
if train_exp_args.success_from_refine_pred:
assert train_exp_args.train_refine_pose and train_exp_args.train_success, \
'Must be training both refine and success to use predictions as success input'
- 作用:同上。
9. 启动训练
train(
mc_vis,
coarse_aff_model, pr_model, success_model,
aff_optimizer, pr_optimizer, sc_optimizer,
train_dataloader, val_dataloader,
aff_loss_fn, pr_loss_fn, sc_loss_fn,
device,
logger,
logdir,
args,
it,
**train_kwargs)
-
作用:调用训练主循环函数
train
,传入所有模型、优化器、数据管道、损失函数、设备、日志、初始迭代号,以及可选的第二模型等额外参数。 -
输入:
mc_vis
: MeshCat 可视化器或None
coarse_aff_model, pr_model, success_model
: 三个torch.nn.Module
aff_optimizer, pr_optimizer, sc_optimizer
: 三个优化器train_dataloader, val_dataloader
aff_loss_fn, pr_loss_fn, sc_loss_fn
: 三个损失函数device
:torch.device('cuda:0')
logger
:SummaryWriter
logdir
: 主目录字符串args
: 完整配置it
: 起始迭代号**train_kwargs
: 如多模型时的第二模型字典
-
输出:无返回值,主要在内部执行:
-
for epoch in range(...)
:-
for batch in train_dataloader
:- 依次调用三段网络前向、计算损失、反向传播、优化器
step()
; - 根据配置在每隔 N 步或每个 epoch 执行验证、可视化、保存 checkpoint。
- 依次调用三段网络前向、计算损失、反向传播、优化器
-
-
-
为什么要这么写:将所有准备好的资源一并传入训练函数,职责分离——
main
只做准备,train
专注运行。
def train( … ):
def train(
mc_vis: meshcat.Visualizer,
coarse_aff_model: nn.Module, refine_pose_model: nn.Module, success_model: nn.Module,
aff_optimizer: Optimizer, pr_optimizer: Optimizer, sc_optimizer: Optimizer,
train_dataloader: DataLoader, test_dataloader: DataLoader,
aff_loss_fn: Callable, pr_loss_fn: Callable, sc_loss_fn: Callable,
dev: torch.device,
logger: SummaryWriter,
logdir: str,
args: config_util.AttrDict,
start_iter: int=0,
**kwargs):
-
作用:定义训练主循环函数,接收可视化句柄、三大模型及其优化器、数据加载管道、三种损失函数、设备、日志写入器、输出目录、配置参数、起始迭代数,以及可选的额外模型(
kwargs
)。 -
设计原因:职责分明——
main
函数准备好所有资源后,一次性传给train
,train
专注于迭代流程。 -
接口
mc_vis
:或None
,用于实时 3D 可视化;coarse_aff_model
/refine_pose_model
/success_model
:继承自nn.Module
,需实现forward(...)
;*_optimizer
:标准torch.optim.Optimizer
,需实现.zero_grad()
,.step()
,.load_state_dict()
等;train_dataloader
/test_dataloader
:torch.utils.data.DataLoader
,迭代返回(coarse_aff_sample, refine_pose_sample, success_sample)
;*_loss_fn
:调用签名如(pred, gt)→loss_tensor
;dev
:torch.device('cuda:0')
;logger
:tensorboard.SummaryWriter
;logdir
:字符串,如"/home/.../exp1"
;args
:AttrDict
,包含args.experiment.batch_size
、args.experiment.num_iterations
等;start_iter
:整型,恢复训练时的初始步数,默认为 0;**kwargs
:如双模型时包含coarse_aff_model2
等。
1. 前期的配置
coarse_aff_model.train()
refine_pose_model.train()
success_model.train()
- 每行作用:将 PyTorch 模型切换到“训练模式”(
train()
),启用 BatchNorm 的更新、Dropout 生效等。 - 示例:若模型中含
nn.Dropout(p=0.5)
,此时在.train()
下,每个前向调用随机丢弃约 50% 神经元;在.eval()
则不丢弃。 - 为什么这么写:保证后续训练时的正则化层行为正确;如果忘记,会导致验证时与训练时分歧巨大。
offset = np.array([0, 0.2, 0.0])
- 作用:定义一个三维偏移向量,用于可视化或数据预处理时在 z 轴方向平移(例如把物体稍微抬高)。
- 示例:
offset = [0.0, 0.2, 0.0]
,表示在 y 方向上抬高 0.2 米。 - 为什么这么写:便于在 MeshCat 中把点云稍微抬高,避免与地面平面重叠导致难以辨识。
bs = args.experiment.batch_size
- 作用:局部变量保存 batch size,减少后面多次访问
args
的开销。 - 示例:若配置中
batch_size=32
,则bs = 32
。 - 为什么这么写:简洁易用,也清晰地将 batch size 引入局部作用域。
it = start_iter
- 作用:将函数参数
start_iter
(恢复时的迭代步数)赋给局部变量it
,用作全局迭代计数器。 - 示例:若恢复训练,
start_iter=2000
,则从第 2000 步开始;否则start_iter=0
。 - 为什么这么写:在训练循环里每处理一个 batch 都会
it += 1
。
voxel_grid_pts = torch.from_numpy(train_dataloader.dataset.raster_pts).float().cuda()
-
作用:
- 从训练集
Dataset
对象中读取预先生成的体素格点坐标数组raster_pts
,形状为(N,3)
,例如(32768,3)
; - 转为
torch.Tensor(dtype=float32)
; .cuda()
放到 GPU。
- 从训练集
-
示例:若
raster_pts
是(32³,3)
个网格点坐标,内容如[[ -0.2, -0.2, -0.2 ], [ -0.2, -0.2, -0.192 ], …]
。 -
为什么这么写:后续在前向或损失计算中需要这些网格点,可避免每个 batch 重复拷贝。
rot_mat_grid = torch.from_numpy(train_dataloader.dataset.rot_grid).float().cuda()
- 作用:同上,但读取旋转矩阵列表
rot_grid
,形状(R,3,3)
,例如(24,3,3)
,然后上 GPU。 - 示例:
rot_grid[0]
可能是绕 x 轴旋转 0°,rot_grid[1]
是绕 x 轴旋转 15°,等等。 - 为什么这么写:后续生成不同旋转角度下的点云候选,需要快速批量矩阵相乘。
args.experiment.dataset_length = len(train_dataloader.dataset)
- 作用:将数据集大小写回配置,供日志或学习率调度使用。
- 示例:若训练集有 10 000 条样本,则
dataset_length = 10000
。 - 为什么这么写:避免手动再去测
len(dataset)
,集中在args
中管理。
if 'coarse_aff_model2' in kwargs:
- 作用:检测是否启用了“双模型”训练模式(第二套 coarse affordance 模型),如果有,则初始化相关变量。
- 示例:当
kwargs={'coarse_aff_model2':{...}}
时进入。
coarse_aff_model2 = kwargs['coarse_aff_model2']['model']
aff_optimizer2 = kwargs['coarse_aff_model2']['opt']
aff_loss_fn2 = kwargs['coarse_aff_model2']['loss_fn']
reso_grid2 = kwargs['coarse_aff_model2']['reso_grid']
padding2 = kwargs['coarse_aff_model2']['padding']
-
每行作用:分别从
kwargs['coarse_aff_model2']
字典中取出第二模型实例、优化器、损失函数、体素分辨率和 padding。 -
示例:
reso_grid2 = 32
(即 32³ 网格);padding2 = 0.1
(米)。
-
为什么这么写:方便后续使用第二套模型的相同训练逻辑。
voxel_grid_pts2 = three_util.get_raster_points(reso_grid2, padding=padding2)
-
作用:调用工具函数
three_util.get_raster_points(n, padding)
,生成(n³,3)
格点在世界坐标系下的坐标。 -
输入:
reso_grid2=32
,padding2=0.1
,
-
输出:
numpy.ndarray
,形状(32768,3)
。 -
为什么这么写:第二模型可能有与第一模型不同的网格参数,需要单独生成。
# reshape to grid, and swap axes (permute x and z), B x reso x reso x reso x 3
voxel_grid_pts2 = voxel_grid_pts2.reshape(reso_grid2, reso_grid2, reso_grid2, 3)
voxel_grid_pts2 = voxel_grid_pts2.transpose(2, 1, 0, 3)
-
作用:
- 把
(N,3)
重塑为(reso,reso,reso,3)
的立方体索引形式; transpose(2,1,0,3)
将原来的(x,y,z)
维度调换成(z,y,x)
,以符合后续函数的排列习惯。
- 把
-
示例:若
reso_grid2=4
,voxel_grid_pts2.reshape(4,4,4,3)
。 -
为什么这么写:某些下游算法(如体素卷积)需要特定维度顺序。
# reshape back to B x N x 3
voxel_grid_pts2 = torch.from_numpy(voxel_grid_pts2.reshape(-1, 3)).float().cuda()
- 作用:把数据恢复成
(N,3)
格式并上 GPU,与voxel_grid_pts
一致。 - 示例:恢复成
(32768,3)
。 - 为什么这么写:按 Tensor 形式并行运算更高效。
coarse_aff_model2 = coarse_aff_model2.train()
- 作用:同第一模型,将第二模型切换到训练模式,同时把返回值(通常是自身)赋回。
- 为什么这么写:确保两套模型都在
train()
模式下。
2. 进入主训练循环
while True:
- 作用:开启一个无限循环,内部以迭代次数
it
控制退出。 - 为什么这么写:可灵活在中途
break
,并在循环内部集中做各种训练步骤。
if it > args.experiment.num_iterations:
break
- 作用:当全局迭代
it
超过配置的最大迭代数num_iterations
(如 100000)时,退出循环。 - 示例:若
num_iterations=50000
且it=50001
,则结束训练。 - 为什么这么写:控制训练总步数。
2.1 可选数据调试
if args.debug_data:
# sample = train_dataloader.dataset[0]
sample = train_dataloader.dataset[371]
# sample = train_dataloader.dataset[1963]
print('[Debug Data] Here with sample')
# for i in range(len(train_dataloader.dataset)):
# sample = train_dataloader.dataset[i]
# if 'parent_start_pcd' not in sample[1][0].keys():
# print(f'[Debug Data] Here with bad sample (index: {i})')
# from IPython import embed; embed()
from IPython import embed; embed()
- 作用:当开启
--debug_data
时,从数据集中取固定索引(371)样本,打印调试信息并进入 IPython 交互。 - 示例:这样可以在训练前检查某个样本是否有缺失字段,或可视化它。
- 为什么这么写:快速定位哪条数据有问题,避免整个训练流程卡在某个 batch 上。
for sample in train_dataloader:
- 作用:遍历
DataLoader
返回的每个 batch,sample
通常是三元组(coarse_aff_sample, refine_pose_sample, success_sample)
。 - 示例:若 batch size=32,则每次
sample
中每个子项的数据量为 32 条。 - 为什么这么写:标准 PyTorch 训练循环写法。
it += 1
- 作用:迭代步数自增 1,用于日志、学习率调度、Checkpoint 保存等。
- 示例:第一次循环时
it
从 0→1,第二次 1→2。 - 为什么这么写:准确记录训练进度。
current_epoch = it * bs / len(train_dataloader.dataset)
- 作用:动态计算当前“epoch”进度:已训练样本数
it*bs
除以总样本数。 - 示例:如果
it=156
、bs=32
、dataset_length=10000
,则current_epoch = 156*32/10000 ≈ 0.4992
,即接近第 0.5 个 epoch。 - 为什么这么写:便于在 TensorBoard 中横坐标以 epoch 为单位显示。
start_time = time.time()
- 作用:记录本次 batch 前的时间戳,用于计算本次前向+反向用时。
- 示例:
start_time = 1700000000.123
(Unix 时间)。 - 为什么这么写:性能监控,判断每个 batch 速度。
coarse_aff_sample, refine_pose_sample, success_sample = sample
-
作用:解包
sample
,分别得到三组任务数据。 -
示例:
coarse_aff_sample = (inputs_dict, gt_dict)
,refine_pose_sample = (inputs_dict, gt_dict)
,success_sample = (inputs_dict, gt_dict)
。
-
为什么这么写:三任务并行,可以选择性地训练。
coarse_aff_mi, coarse_aff_gt = coarse_aff_sample
refine_pose_mi, refine_pose_gt = refine_pose_sample
success_mi, success_gt = success_sample
-
作用:进一步解包——
mi
表示模型输入字典(model input),gt
表示 ground-truth 字典。 -
示例:
coarse_aff_mi = {'pointcloud':Tensor(32,1024,3), 'voxel_centers':Tensor(32768,3), …} coarse_aff_gt = {'aff_labels':Tensor(32,32768), …}
-
为什么这么写:保持统一命名,后续可写
dict_to_gpu(coarse_aff_mi)
。
coarse_aff_out = None
refine_pose_out = None
success_out = None
- 作用:先把输出占位为
None
,以便后面条件分支后检查是否生成。 - 为什么这么写:清晰标识三种输出可能未被某些分支产生。
loss_dict = {}
- 作用:初始化一个空字典,后续把各任务的损失值收集进来,用于一次性日志打印。
- 为什么这么写:统一管理不同子任务的损失输出。
2.2 训练粗可用性子任务
if args.experiment.train.train_coarse_aff and (len(coarse_aff_mi) > 0):
- 作用:仅当配置开启该子任务,且输入非空时才执行该块。
- 示例:
args.experiment.train.train_coarse_aff=True
,且coarse_aff_mi
包含若干点云,则进入。
# prepare input and gt
coarse_aff_mi = dict_to_gpu(coarse_aff_mi)
coarse_aff_gt = dict_to_gpu(coarse_aff_gt)
-
作用:调用工具函数
dict_to_gpu
,将字典中所有Tensor
从 CPU 移到 GPU。 -
接口:
def dict_to_gpu(d: Dict[str,Tensor]) -> Dict[str,Tensor]: return {k:v.cuda() for k,v in d.items()}
-
为什么这么写:模型在 GPU 上运算,输入数据必须位于同一设备。
coarse_aff_out = train_iter_coarse_aff(
coarse_aff_mi,
coarse_aff_gt,
coarse_aff_model,
aff_optimizer,
aff_loss_fn,
args,
voxel_grid_pts, args.data.voxel_grid.reso_grid,
rot_mat_grid, args.data.rot_grid_bins,
it, current_epoch,
logger,
mc_vis=mc_vis)
-
作用:调用“单次训练迭代”函数
train_iter_coarse_aff
,完成一次前向、损失、反向、优化并可视化。 -
接口:
def train_iter_coarse_aff( inputs:Dict, gt:Dict, model:nn.Module, optimizer:Optimizer, loss_fn:Callable, args:AttrDict, voxel_pts:Tensor, reso:int, rot_mats:Tensor, rot_bins:int, iter_num:int, epoch:float, logger:SummaryWriter, mc_vis:meshcat.Visualizer=None, refine_pred:bool=False ) -> Dict[str,Any]:
-
输入:如上;
-
输出:
coarse_aff_out
是字典,包含{ 'model_output': Tensor, # 原始网络预测 'loss': {'aff_loss': Tensor(...), ...}, 'viz_data': {...} # 可视化所需 }
-
-
为什么这么写:把核心训练逻辑封装,
train
保持清晰。
# process output for logging
for k, v in coarse_aff_out['loss'].items():
loss_dict[k] = v
- 作用:遍历该子任务返回的所有损失项(可能有主损失 + 辅助损失),收集到
loss_dict
,后续统一打印/写入 TensorBoard。 - 示例:可能有
{'aff_loss': Tensor(0.123), 'rot_loss': Tensor(0.045)}
。
2.2.1 结合粗预测生成 refine 输入
if args.experiment.train.coarse_aff_from_coarse_pred:
# process coarse prediction and refine input to create new refine input
- 作用:若配置要求用本次粗预测结果再做一次粗预测,则进入。
if args.model.coarse_aff.multi_model:
aff_refine_model = coarse_aff_model2
aff_refine_opt = aff_optimizer2
aff_refine_loss_fn = aff_loss_fn2
aff_refine_voxel_grid_pts = voxel_grid_pts2
aff_refine_reso_grid = reso_grid2
else:
aff_refine_model = coarse_aff_model
aff_refine_opt = aff_optimizer
aff_refine_loss_fn = aff_loss_fn
aff_refine_voxel_grid_pts = voxel_grid_pts
aff_refine_reso_grid = args.data.voxel_grid.reso_grid
- 作用:根据是否启用双模型,选择做 refine 的那一套模型、参数和网格。
coarse_aff_mi, coarse_aff_gt = coarse_aff_to_coarse_aff(
coarse_aff_mi, coarse_aff_out['model_output'], coarse_aff_gt,
rot_mat_grid, voxel_grid_pts, aff_refine_voxel_grid_pts, aff_refine_reso_grid, args, mc_vis=mc_vis)
-
作用:调用转换函数
coarse_aff_to_coarse_aff
,将粗预测结果model_output
结合原始输入/标签,生成新的输入/标签对,用于二次迭代。 -
接口:
def coarse_aff_to_coarse_aff( mi:Dict, pred:Tensor, gt:Dict, rot_mats:Tensor, voxel_pts1:Tensor, voxel_pts2:Tensor, reso2:int, args:AttrDict, mc_vis:Visualizer=None ) -> Tuple[Dict,Dict]: # 1. 根据 pred 生成新的体素网格预测标签 # 2. 构建新的 mi/gt 字典
-
为什么这么写:尝试多次迭代 refine,提升可用性预测精度。
coarse_aff_mi = dict_to_gpu(coarse_aff_mi)
coarse_aff_gt = dict_to_gpu(coarse_aff_gt)
- 作用:同上,把新生成的输入/标签搬到 GPU。
coarse_aff_out = train_iter_coarse_aff(
coarse_aff_mi,
coarse_aff_gt,
aff_refine_model,
aff_refine_opt,
aff_refine_loss_fn,
args,
aff_refine_voxel_grid_pts, aff_refine_reso_grid,
rot_mat_grid, args.data.rot_grid_bins,
it, current_epoch,
logger, refine_pred=True,
mc_vis=mc_vis)
- 作用:再跑一次
train_iter_coarse_aff
,并在输出中标记refine_pred=True
以区别首轮与二轮迭代。 - 为什么这么写:同上。
2.2.2 从粗预测生成 refine_pose 输入
if args.experiment.train.refine_pose_from_coarse_pred:
# process coarse prediction and refine input to create new refine input
refine_pose_mi, refine_pose_gt = coarse_aff_to_refine_pose(
coarse_aff_mi, coarse_aff_out['model_output'], coarse_aff_gt,
refine_pose_mi, refine_pose_gt,
rot_mat_grid, voxel_grid_pts, args, mc_vis=mc_vis)
-
作用:如果配置要求,用粗预测直接生成给微调网络的新输入/标签,调用
coarse_aff_to_refine_pose
。 -
接口:
def coarse_aff_to_refine_pose( coarse_mi, coarse_pred, coarse_gt, prev_refine_mi, prev_refine_gt, rot_mats, voxel_pts, args, mc_vis=None ) -> Tuple[new_refine_mi,new_refine_gt]: # 1. 从 coarse_pred 提取最可能的位姿候选 # 2. 更新 refine_pose 的输入/标签
-
为什么这么写:连贯地将两个任务串联,训练过程中让 refine 网络看到最新的粗预测。
2.2.3 从粗预测生成 success 输入
if args.experiment.train.success_from_coarse_pred:
# process coarse prediction and success input to create new success input
success_mi, success_gt = coarse_aff_to_success(
coarse_aff_mi, coarse_aff_out['model_output'], coarse_aff_gt,
success_mi, success_gt,
rot_mat_grid, voxel_grid_pts, args, mc_vis=mc_vis)
- 作用:同样地,用粗预测生成成功判断子任务的输入/标签,调用
coarse_aff_to_success
。 - 为什么这么写:通过粗预测结果辅助训练成功分类器。
2.3 训练微调(Refine Pose)子任务
if args.experiment.train.train_refine_pose and (len(refine_pose_mi) > 0):
- 作用:若打开该子任务且输入非空,则进入。
# prepare input and gt
refine_pose_mi = dict_to_gpu(refine_pose_mi)
refine_pose_gt = dict_to_gpu(refine_pose_gt)
- 作用:上 GPU。
refine_pose_out = train_iter_refine_pose(
refine_pose_mi,
refine_pose_gt,
refine_pose_model,
pr_optimizer,
pr_loss_fn,
args,
it, current_epoch,
logger,
mc_vis=mc_vis)
-
接口:
def train_iter_refine_pose( mi:Dict, gt:Dict, model:nn.Module, optimizer:Optimizer, loss_fn:Callable, args:AttrDict, iter_num:int, epoch:float, logger:SummaryWriter, mc_vis:Visualizer=None ) -> Dict[str,Any]: # 1. optimizer.zero_grad() # 2. pred = model(mi) # 3. loss = loss_fn(pred, gt) # 4. loss.backward() # 5. optimizer.step() # 6. logger.add_scalar... # 7. 可视化(mc_vis) # 8. return {'model_output':pred, 'loss':{'pose_loss':...}}
-
为什么这么写:同样封装迭代细节。
# process output for logging
for k, v in refine_pose_out['loss'].items():
loss_dict[k] = v
- 作用:收集该子任务的所有损失项到
loss_dict
。
if args.experiment.train.success_from_refine_pred:
# process coarse prediction and success input to create new success input
success_mi, success_gt = refine_pose_to_success(
refine_pose_mi, refine_pose_out['model_output'], refine_pose_gt,
success_mi, success_gt,
args, mc_vis=mc_vis)
- 作用:若开启“从 refine 输出预测 success”,调用
refine_pose_to_success
,生成 success 输入/标签。
2.4 训练成功分类(Success)子任务
if args.experiment.train.train_success and (len(success_mi) > 0):
- 作用:若配置开启且输入非空,则进入。
if args.experiment.train.success_from_refine_pred and (len(refine_pose_mi) == 0):
print(f'Skipping success due to no refine pose model input')
continue
- 作用:若配置要求从 refine 预测生成 success,但 refine 模块未提供输入(长度为 0),则跳过本 batch 的 success 部分,继续下一个 batch。
- 为什么这么写:防止空输入导致报错。
# prepare input and gt
success_mi = dict_to_gpu(success_mi)
success_gt = dict_to_gpu(success_gt)
- 作用:上 GPU。
success_out = train_iter_success(
success_mi,
success_gt,
success_model,
sc_optimizer,
sc_loss_fn,
args,
it, current_epoch,
logger,
mc_vis=mc_vis)
-
接口:
def train_iter_success( mi:Dict, gt:Dict, model:nn.Module, optimizer:Optimizer, loss_fn:Callable, args:AttrDict, iter_num:int, epoch:float, logger:SummaryWriter, mc_vis:Visualizer=None ) -> Dict[str,Any]: # 同上结构
-
为什么这么写:与前两者保持一致、易于维护。
# process output for logging
for k, v in success_out['loss'].items():
loss_dict[k] = v
- 作用:收集 success 子任务的损失项。
2.5 日志打印 与 Checkpoint 保存
if it % args.experiment.log_interval == 0 and args.experiment.train.out_log_full:
- 作用:每隔
log_interval
步(如 100 步)且开启完整日志输出时,执行打印和 TensorBoard 写入。 - 示例:若
log_interval=50
,当it=50,100,150…
时触发。
string = f'Iteration {it} -- '
- 作用:初始化日志字符串。
for loss_name, loss_val in loss_dict.items():
if isinstance(loss_val, dict):
# don't let these loss dicts get more than two levels deep
for k, v in loss_val.items():
string += f'{k}: {v.mean().item():.6f} '
logger.add_scalar(k, v.mean().item(), it)
else:
string += f'{loss_name}: {loss_val.mean().item():.6f} '
logger.add_scalar(loss_name, loss_val.mean().item(), it)
-
作用:
- 遍历
loss_dict
,支持两层嵌套(如某些子损失又返回字典) - 对每个张量
v
做.mean().item()
,得到 Python float,如0.123456
; - 拼入日志字符串;
- 用
logger.add_scalar(name, value, it)
写入 TensorBoard。
- 遍历
-
示例:
Iteration 100 -- aff_loss: 0.123456 pose_loss: 0.045678 success_loss: 0.210000
if args.experiment.debug:
from IPython import embed; embed()
- 作用:若开启 debug 模式,进入交互式终端,方便现场调试中断点。
- 为什么这么写:实用的调试手段。
end_time = time.time()
total_duration = end_time - start_time
- 作用:计算本次 batch 的总耗时,如
0.2345
秒。
string += f'duration: {total_duration:.4f}'
print(string)
-
作用:将耗时拼接进入日志并打印到屏幕,如
Iteration 100 -- aff_loss:0.123456 pose_loss:0.045678 success_loss:0.210000 duration:0.2345
if it % args.experiment.save_interval == 0 and it > 0:
- 作用:每隔
save_interval
步(如 1000)且it>0
时,保存模型 checkpoint。 - 示例:当
save_interval=500
且it=500,1000,1500…
时触发。
model_path = osp.join(logdir, f'model_{it}.pth')
model_path_latest = osp.join(logdir, 'model_latest.pth')
- 作用:拼出当前迭代号和“最新”两个文件名,前者用于版本管理,后者始终指向最新模型。
ckpt = {'args': config_util.recursive_dict(args)}
- 作用:新建 checkpoint 字典,先保存配置快照。
- 为什么这么写:确保重现训练环境。
ckpt['coarse_aff_model_state_dict'] = coarse_aff_model.state_dict()
ckpt['refine_pose_model_state_dict'] = refine_pose_model.state_dict()
ckpt['success_model_state_dict'] = success_model.state_dict()
- 作用:分别保存三大模型的权重字典(
dict[str,Tensor]
)。
ckpt['aff_optimizer_state_dict'] = aff_optimizer.state_dict()
ckpt['pr_optimizer_state_dict'] = pr_optimizer.state_dict()
ckpt['sc_optimizer_state_dict'] = sc_optimizer.state_dict()
- 作用:保存对应优化器的内部状态(包括动量、学习率调度器步数等)。
if args.model.coarse_aff.multi_model:
ckpt['coarse_aff_model_state_dict2'] = coarse_aff_model2.state_dict()
ckpt['aff_optimizer_state_dict2'] = aff_optimizer2.state_dict()
- 作用:若启用双模型,再保存第二模型及其优化器状态。
torch.save(ckpt, model_path)
torch.save(ckpt, model_path_latest)
- 作用:将
ckpt
序列化存盘,两个路径各写一份。 - 为什么这么写:保留历史版本同时更新“最新”链接,便于恢复最新或指定版本。