Dexcap复现代码运行逻辑全流程(N+1)——run_trained_agent.py

Dexcap源码中,Acknowledgements代表了部署通讯过程:

run_trained_agent.py 文件,是 robomimic 中明确的仿真 eval 脚本

因此,Dexcap 的部署,也和 run_trained_agent.py 有着密切关联

前置基础查看 robomimic应用教程(二)——策略运行与评估

目录

一、部署逻辑

1. 输入

2. 模型

3. 输出

二、代码解析

1. 库及模块导入

2. 定义辅助函数 rollout

2.1 参数

2.2 返回值

2.3 初始化

2.4 主循环

2.5 异常处理

2.6 回合结束处理

2.7 返回值

3. 定义主函数 run_trained_agent

3.1 参数

3.2 参数检查和设置

***3.3 加载策略检查点

3.4 设置回合参数

3.5 创建环境

3.6 设置随机种子

3.7 初始化视频和数据写入器

3.8 执行回合

3.9 回合结束处理

4. 解析命令行参数并执行脚本

4.1 命令行参数

4.2 代码运行


一、部署逻辑

有一个核心逻辑在于:dexcap 从相机获取输入,然后通过 diffusion policy生产动作轨迹,最后发送给机器人和末端执行器执行

1. 输入

pth 文件即扩散策略模型文件

2. 模型

训练得到 pth 文件,该文件包含了扩散策略模型的参数(权重和偏置)

3. 输出

 训练后的 policy 部署后会输出 46 维的动作空间,组成如下(按顺序):

  • 3维右臂末端执行器的平移: 控制右臂末端在空间中的位置移动

  • 3维左臂末端执行器的平移: 控制左臂末端在空间中的位置移动

  • 4维右臂末端执行器的四元数方向: 使用四元数表示右臂末端的旋转方向

  • 4维左臂末端执行器的四元数方向: 使用四元数表示左臂末端的旋转方向

  • 16维右手LEAP手部关节位置: 控制右手LEAP机械手的各个关节位置

  • 16维左手LEAP手部关节位置: 控制左手LEAP机械手的各个关节位置

目前只能通过真实的机器人来验证学习到的策略

机器人:利用输出的7维(3维平移 + 4维方向)数据进行末端执行器的位置控制

灵巧手:按照开源的 LEAPHAND 手册自行构建

手腕和手指的重新定位是分开的,首先执行末端执行器位置控制,让机器人达到从人类数据(手腕相机帧)捕获的手掌姿势,然后在手腕帧内执行指尖IK

二、代码解析

此脚本的主要功能是使用 robomimic 库在环境中评估训练好的 policy

支持在屏幕上渲染、录制视频、保存回放数据等功能

1. 库及模块导入

# 导入所需的库
import argparse  # 用于解析命令行参数
import os
import json
import h5py  # 用于处理HDF5文件格式
import imageio  # 用于读写图像和视频
import sys
import time
import traceback  # 用于获取异常的详细信息
import numpy as np
from copy import deepcopy  # 用于深拷贝对象
from tqdm import tqdm  # 用于显示进度条

import torch  # PyTorch库

# 导入robomimic库及其子模块
import robomimic
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.utils.log_utils import log_warning
from robomimic.envs.env_base import EnvBase
from robomimic.envs.wrappers import EnvWrapper
from robomimic.algo import RolloutPolicy
from robomimic.scripts.playback_dataset import DEFAULT_CAMERAS  # 默认摄像机配置

导入了标准库、pytorch 和 robomimic

2. 定义辅助函数 rollout

# 定义一个函数,用于执行一次策略的回合(rollout)
def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, return_obs=False, camera_names=None, real=False, rate_measure=None):
    """
    Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video, 
    and returns the rollout trajectory.

    Args:
        policy (instance of RolloutPolicy): policy loaded from a checkpoint
        env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
        horizon (int): maximum horizon for the rollout
        render (bool): whether to render rollout on-screen
        video_writer (imageio writer): if provided, use to write rollout to video
        video_skip (int): how often to write video frames
        return_obs (bool): if True, return possibly high-dimensional observations along the trajectory. 
            They are excluded by default because the low-dimensional simulation states should be a minimal 
            representation of the environment. 
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
        real (bool): if real robot rollout
        rate_measure: if provided, measure rate of action computation and do not play actions in environment

    Returns:
        stats (dict): some statistics for the rollout - such as return, horizon, and task success
        traj (dict): dictionary that corresponds to the rollout trajectory
    """
    rollout_timestamp = time.time()  # 记录当前时间戳,便于计算回合耗时
    assert isinstance(env, EnvBase) or isinstance(env, EnvWrapper)  # 确保环境是EnvBase或EnvWrapper的实例
    assert isinstance(policy, RolloutPolicy)  # 确保策略是RolloutPolicy的实例
    assert not (render and (video_writer is not None))  # 确保不会同时进行屏幕渲染和视频录制

    policy.start_episode()  # 通知策略开始一个新的回合
    obs = env.reset()  # 重置环境,获取初始观测
    state_dict = dict()  # 用于存储环境状态(主要用于模拟环境)
    if real:
        input("ready for next eval? hit enter to continue")  # 如果是真实机器人环境,等待用户确认开始
    else:
        state_dict = env.get_state()  # 获取当前环境状态
        # 对于robosuite任务,为了确保动作的可重复性,需要重置到获取的状态
        obs = env.reset_to(state_dict)

    results = {}  # 用于存储回合结果的字典
    video_count = 0  # 视频帧计数器
    total_reward = 0.  # 累积奖励
    got_exception = False  # 用于标记是否在回合中遇到异常
    success = env.is_success()["task"]  # 检查任务是否成功
    traj = dict(actions=[], rewards=[], dones=[], states=[], initial_state_dict=state_dict)  # 用于存储回合轨迹的数据
    if return_obs:
        # 如果需要返回观测数据,初始化相应的列表
        traj.update(dict(obs=[], next_obs=[]))
    try:
        for step_i in range(horizon):  # 开始回合的主循环,最多运行到指定的horizon步
            # HACK: some keys on real robot do not have a shape (and then they get frame stacked)
            for k in obs:
                if len(obs[k].shape) == 1:
                    obs[k] = obs[k][..., None]  # 确保观测数据的形状正确

            # get action from policy
            t1 = time.time()  # 记录动作计算开始时间
            act = policy(ob=obs)  # 使用策略计算动作
            t2 = time.time()  # 记录动作计算结束时间
            if real and (not env.base_env.controller_type == "JOINT_IMPEDANCE") and (policy.policy.global_config.algo_name != "diffusion_policy"):
                # 如果是真实机器人环境,且控制器不是JOINT_IMPEDANCE,且策略不是diffusion_policy
                # 为了安全,动作值需要裁剪到[-1, 1]范围内
                act = np.clip(act, -1., 1.)

            if rate_measure is not None:
                # 如果提供了速率测量工具,则测量动作计算速率,而不在环境中执行动作
                rate_measure.measure()
                print("time: {}s".format(t2 - t1))  # 打印动作计算耗时
                # dummy reward and done
                r = 0.
                done = False
                next_obs = obs
            else:
                # play action
                next_obs, r, done, _ = env.step(act)  # 在环境中执行动作,获取下一个观测、奖励、是否完成等信息

            # compute reward
            total_reward += r  # 累积奖励
            success = env.is_success()["task"]  # 检查任务是否成功

            # visualization
            if render:
                env.render(mode="human", camera_name=camera_names[0])  # 在屏幕上渲染环境
            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = []
                    for cam_name in camera_names:
                        # 从指定的摄像机获取渲染图像
                        video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                    video_img = np.concatenate(video_img, axis=1)  # 将多个摄像机的图像水平拼接
                    video_writer.append_data(video_img)  # 将图像写入视频
                video_count += 1  # 增加视频帧计数

            # collect transition
            traj["actions"].append(act)  # 记录动作
            traj["rewards"].append(r)  # 记录奖励
            traj["dones"].append(done)  # 记录是否完成
            if not real:
                traj["states"].append(state_dict["states"])  # 记录环境状态(仅用于模拟环境)
            if return_obs:
                # 需要将观测数据“反处理”,以便于保存到数据集中
                traj["obs"].append(ObsUtils.unprocess_obs_dict(obs))
                traj["next_obs"].append(ObsUtils.unprocess_obs_dict(next_obs))

            # break if done or if success
            if done or success:
                break  # 如果任务完成或达到终止条件,退出循环

            # update for next iter
            obs = deepcopy(next_obs)  # 更新当前观测
            if not real:
                state_dict = env.get_state()  # 更新环境状态(仅用于模拟环境)

    except env.rollout_exceptions as e:
        # 捕获在回合中可能发生的异常,记录警告信息
        print("WARNING: got rollout exception {}".format(e))
        got_exception = True

    stats = dict(
        Return=total_reward,  # 总奖励
        Horizon=(step_i + 1),  # 实际运行的步数
        Success_Rate=float(success),  # 任务成功率
        Exception_Rate=float(got_exception),  # 异常发生率
        time=(time.time() - rollout_timestamp),  # 回合耗时
    )

    if return_obs:
        # 将观测数据从列表转换为字典,以便于保存
        traj["obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["obs"])
        traj["next_obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["next_obs"])

    # list to numpy array
    for k in traj:
        if k == "initial_state_dict":
            continue  # 初始状态字典不需要转换
        if isinstance(traj[k], dict):
            for kp in traj[k]:
                traj[k][kp] = np.array(traj[k][kp])  # 将数据转换为NumPy数组
        else:
            traj[k] = np.array(traj[k])

    return stats, traj  # 返回回合统计信息和轨迹数据

执行一次在环境中的回合(rollout),即让策略在环境中运行一段时间,收集数据和统计信息

2.1 参数

  • policyinstance of RolloutPolicy:训练好的策略,用于在给定观察下输出动作,从 checkpoint 中加载
  • envinstance of EnvBase:环境实例,可以是模拟环境或真实机器人环境,从 checkpoint 或者演示元数据中加载
  • horizonint:回合的最大步数,即最大运行时间
  • renderbool:是否在屏幕上渲染 rollout
  • video_writerimageio writer:如果提供,将把渲染的图像写入视频文件
  • video_skipint:每隔多少步写入一帧到视频中
  • return_obsbool:是否在返回的数据中包含观测信息,True 则沿着轨迹返回可能的高维观测值,default 情况下被排除,因为低维模拟状态应该是环境的最小表示
  • camera_nameslist:用于渲染的摄像机名称列表
  • realbool:是否在真实机器人上运行
  • rate_measure:用于测量动作计算速率的工具

2.2 返回值

  • statsdict:rollout 统计数据,如 return, horizon, and task success
  • trajdict:rollout 轨迹数据(字典格式)

2.3 初始化

    rollout_timestamp = time.time()  # 记录当前时间戳,便于计算回合耗时
    assert isinstance(env, EnvBase) or isinstance(env, EnvWrapper)  # 确保环境是EnvBase或EnvWrapper的实例
    assert isinstance(policy, RolloutPolicy)  # 确保策略是RolloutPolicy的实例
    assert not (render and (video_writer is not None))  # 确保不会同时进行屏幕渲染和视频录制
  • 记录回合的开始时间
  • 检查 env 是否是合法的环境实例
  • 确保不会同时进行屏幕渲染和视频录制
    policy.start_episode()  # 通知策略开始一个新的回合
    obs = env.reset()  # 重置环境,获取初始观测
    state_dict = dict()  # 用于存储环境状态(主要用于模拟环境)
    if real:
        input("ready for next eval? hit enter to continue")  # 如果是真实机器人环境,等待用户确认开始
    else:
        state_dict = env.get_state()  # 获取当前环境状态
        # 对于robosuite任务,为了确保动作的可重复性,需要重置到获取的状态
        obs = env.reset_to(state_dict)

    results = {}  # 用于存储回合结果的字典
    video_count = 0  # 视频帧计数器
    total_reward = 0.  # 累积奖励
    got_exception = False  # 用于标记是否在回合中遇到异常
    success = env.is_success()["task"]  # 检查任务是否成功
    traj = dict(actions=[], rewards=[], dones=[], states=[], initial_state_dict=state_dict)  # 用于存储回合轨迹的数据
    if return_obs:
        # 如果需要返回观测数据,初始化相应的列表
        traj.update(dict(obs=[], next_obs=[]))
  • 调用 policy.start_episode() 开始新的策略回合
  • 重置环境,获取初始观测 obs
  • 如果是在真实机器人上运行,等待用户确认开始
  • 初始化用于存储回合数据的字典 traj 和统计信息的字典 stats

2.4 主循环

    try:
        for step_i in range(horizon):  # 开始回合的主循环,最多运行到指定的horizon步
            # HACK: some keys on real robot do not have a shape (and then they get frame stacked)
            for k in obs:
                if len(obs[k].shape) == 1:
                    obs[k] = obs[k][..., None]  # 确保观测数据的形状正确
  • 处理观测数据:确保观测数据的形状正确,以避免后续处理中的错误
            # get action from policy
            t1 = time.time()  # 记录动作计算开始时间
            act = policy(ob=obs)  # 使用策略计算动作
            t2 = time.time()  # 记录动作计算结束时间
            if real and (not env.base_env.controller_type == "JOINT_IMPEDANCE") and (policy.policy.global_config.algo_name != "diffusion_policy"):
                # 如果是真实机器人环境,且控制器不是JOINT_IMPEDANCE,且策略不是diffusion_policy
                # 为了安全,动作值需要裁剪到[-1, 1]范围内
                act = np.clip(act, -1., 1.)
  • 计算动作:使用策略根据当前观测计算下一步的动作
  • 测量计算时间:记录动作计算前后的时间,以评估计算性能
  • 动作裁剪:在真实机器人上,控制器不是 JOINT_IMPEDANCE,且策略不是 diffusion_policy,为了安全,可能需要将动作值裁剪到 [-1, 1] 的范围内
            if rate_measure is not None:
                # 如果提供了速率测量工具,则测量动作计算速率,而不在环境中执行动作
                rate_measure.measure()
                print("time: {}s".format(t2 - t1))  # 打印动作计算耗时
                # dummy reward and done
                r = 0.
                done = False
                next_obs = obs
            else:
                # play action
                next_obs, r, done, _ = env.step(act)  # 在环境中执行动作,获取下一个观测、奖励、是否完成等信息

            # compute reward
            total_reward += r  # 累积奖励
            success = env.is_success()["task"]  # 检查任务是否成功
  • 速率测量:如果提供了 rate_measure,则测量动作计算的速率,并跳过实际的环境交互
  • 环境交互:将动作应用于环境,获取下一个观测、奖励、是否完成以及额外信息
  • 收集奖励和成功信息:累积总奖励,并检查任务是否成功完成
            # visualization
            if render:
                env.render(mode="human", camera_name=camera_names[0])  # 在屏幕上渲染环境
            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = []
                    for cam_name in camera_names:
                        # 从指定的摄像机获取渲染图像
                        video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                    video_img = np.concatenate(video_img, axis=1)  # 将多个摄像机的图像水平拼接
                    video_writer.append_data(video_img)  # 将图像写入视频
                video_count += 1  # 增加视频帧计数
  • 渲染和视频录制:根据参数,进行屏幕渲染或将图像写入视频文件
            # collect transition
            traj["actions"].append(act)  # 记录动作
            traj["rewards"].append(r)  # 记录奖励
            traj["dones"].append(done)  # 记录是否完成
            if not real:
                traj["states"].append(state_dict["states"])  # 记录环境状态(仅用于模拟环境)
            if return_obs:
                # 需要将观测数据“反处理”,以便于保存到数据集中
                traj["obs"].append(ObsUtils.unprocess_obs_dict(obs))
                traj["next_obs"].append(ObsUtils.unprocess_obs_dict(next_obs))
  • 收集数据:将动作、奖励、观测等信息存储到 traj 字典中
            # break if done or if success
            if done or success:
                break  # 如果任务完成或达到终止条件,退出循环
  • 提前结束条件:如果任务成功或达到终止条件,提前结束回合
            # update for next iter
            obs = deepcopy(next_obs)  # 更新当前观测
            if not real:
                state_dict = env.get_state()  # 更新环境状态(仅用于模拟环境)
  • 更新观测和状态:为下一步的计算准备

2.5 异常处理

    except env.rollout_exceptions as e:
        # 捕获在回合中可能发生的异常,记录警告信息
        print("WARNING: got rollout exception {}".format(e))
        got_exception = True
  • 捕获在回合过程中可能发生的异常,记录警告信息

2.6 回合结束处理

    stats = dict(
        Return=total_reward,  # 总奖励
        Horizon=(step_i + 1),  # 实际运行的步数
        Success_Rate=float(success),  # 任务成功率
        Exception_Rate=float(got_exception),  # 异常发生率
        time=(time.time() - rollout_timestamp),  # 回合耗时
    )

    if return_obs:
        # 将观测数据从列表转换为字典,以便于保存
        traj["obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["obs"])
        traj["next_obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["next_obs"])
  • 计算回合的统计信息,包括总奖励、步数、是否成功、异常发生率和总时间
    # list to numpy array
    for k in traj:
        if k == "initial_state_dict":
            continue  # 初始状态字典不需要转换
        if isinstance(traj[k], dict):
            for kp in traj[k]:
                traj[k][kp] = np.array(traj[k][kp])  # 将数据转换为NumPy数组
        else:
            traj[k] = np.array(traj[k])
  • 将收集的数据从列表转换为 NumPy 数组,便于后续处理或保存

2.7 返回值

    return stats, traj  # 返回回合统计信息和轨迹数据
  • stats:包含回合统计信息的字典
  • traj:包含回合轨迹数据的字典

3. 定义主函数 run_trained_agent

# 主函数,用于根据命令行参数执行策略评估
def run_trained_agent(args):
    # 一些参数检查
    write_video = (args.video_path is not None)  # 判断是否需要录制视频
    assert not (args.render and write_video)  # 确保不会同时进行屏幕渲染和视频录制

    rate_measure = None  # 用于测量动作计算速率的工具
    if args.hz is not None:
        import RobotTeleop
        from RobotTeleop.utils import Rate, RateMeasure, Timers
        rate_measure = RateMeasure(name="control_rate_measure", freq_threshold=args.hz)
    
    # 加载策略检查点,并获取算法名称
    algo_name, ckpt_dict = FileUtils.algo_name_from_checkpoint(ckpt_path=args.agent)

    if args.dp_eval_steps is not None:
        # 如果指定了dp_eval_steps参数,并且算法是diffusion_policy,则修改配置
        assert algo_name == "diffusion_policy"
        log_warning("setting @num_inference_steps to {}".format(args.dp_eval_steps))

        # 修改配置,然后重新写回ckpt_dict
        tmp_config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
        with tmp_config.values_unlocked():
            if tmp_config.algo.ddpm.enabled:
                tmp_config.algo.ddpm.num_inference_timesteps = args.dp_eval_steps
            elif tmp_config.algo.ddim.enabled:
                tmp_config.algo.ddim.num_inference_timesteps = args.dp_eval_steps
            else:
                raise Exception("should not reach here")
        ckpt_dict['config'] = tmp_config.dump()

    # 确定设备(CPU或GPU)
    device = TorchUtils.get_torch_device(try_to_use_cuda=True)

    # 从检查点恢复策略
    policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_dict=ckpt_dict, device=device, verbose=True)

    # 读取回合设置
    rollout_num_episodes = args.n_rollouts  # 回合数量
    rollout_horizon = args.horizon  # 回合最大步数
    config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
    if rollout_horizon is None:
        # 如果未指定horizon参数,则从配置中获取
        rollout_horizon = config.experiment.rollout.horizon

    # HACK: assume absolute actions for now if using diffusion policy on real robot
    if (algo_name == "diffusion_policy") and EnvUtils.is_real_robot_gprs_env(env_meta=ckpt_dict["env_metadata"]):
        ckpt_dict["env_metadata"]["env_kwargs"]["absolute_actions"] = True

    # 从保存的检查点创建环境
    env, _ = FileUtils.env_from_checkpoint(
        ckpt_dict=ckpt_dict, 
        env_name=args.env, 
        render=args.render, 
        render_offscreen=(args.video_path is not None), 
        verbose=True,
    )

    # 如果未指定camera_names,则使用默认摄像机配置
    if args.camera_names is None:
        # We fill in the automatic values
        env_type = EnvUtils.get_env_type(env=env)
        args.camera_names = DEFAULT_CAMERAS[env_type]
    if args.render:
        # 屏幕渲染只能支持一个摄像机
        assert len(args.camera_names) == 1

    is_real_robot = EnvUtils.is_real_robot_env(env=env) or EnvUtils.is_real_robot_gprs_env(env=env)
    if is_real_robot:
        # 如果是真实机器人环境,记录一些警告信息
        need_pause = False
        if "env_name" not in ckpt_dict["env_metadata"]["env_kwargs"]:
            log_warning("env_name not in checkpoint...proceed with caution...")
            need_pause = True
        if ckpt_dict["env_metadata"]["env_name"] != "EnvRealPandaGPRS":
            # 即使策略是在不同的环境类中收集的,我们也将默认加载EnvRealPandaGPRS类
            log_warning("env name in metadata appears to be class ({}) different from EnvRealPandaGPRS".format(ckpt_dict["env_metadata"]["env_name"]))
            need_pause = True
        if need_pause:
            ans = input("continue? (y/n)")
            if ans != "y":
                exit()

    # 如果提供了种子参数,则设置随机种子
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # 可能需要创建视频写入器
    video_writer = None
    if write_video:
        video_writer = imageio.get_writer(args.video_path, fps=20)

    # 如果需要写入数据集,打开HDF5文件
    write_dataset = (args.dataset_path is not None)
    if write_dataset:
        data_writer = h5py.File(args.dataset_path, "w")
        data_grp = data_writer.create_group("data")
        total_samples = 0

    rollout_stats = []  # 用于存储每个回合的统计信息
    for i in tqdm(range(rollout_num_episodes)):  # 使用进度条显示回合进度
        try:
            stats, traj = rollout(
                policy=policy, 
                env=env, 
                horizon=rollout_horizon, 
                render=args.render, 
                video_writer=video_writer, 
                video_skip=args.video_skip, 
                return_obs=(write_dataset and args.dataset_obs),
                camera_names=args.camera_names,
                real=is_real_robot,
                rate_measure=rate_measure,
            )
        except KeyboardInterrupt:
            if is_real_robot:
                print("ctrl-C catched, stop execution")
                print("env rate measure")
                print(env.rate_measure)
                ans = input("success? (y / n)")
                rollout_stats.append((1 if ans == "y" else 0))
                print("*" * 50)
                print("have {} success out of {} attempts".format(np.sum(rollout_stats), len(rollout_stats)))
                print("*" * 50)
                continue
            else:
                sys.exit(0)
        
        if is_real_robot:
            print("TERMINATE WITHOUT KEYBOARD INTERRUPT...")
            ans = input("success? (y / n)")
            rollout_stats.append((1 if ans == "y" else 0))
            continue
        rollout_stats.append(stats)

        if write_dataset:
            # 将数据集写入HDF5文件
            ep_data_grp = data_grp.create_group("demo_{}".format(i))
            ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
            ep_data_grp.create_dataset("states", data=np.array(traj["states"]))
            ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
            ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
            if args.dataset_obs:
                for k in traj["obs"]:
                    ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]))
                    ep_data_grp.create_dataset("next_obs/{}".format(k), data=np.array(traj["next_obs"][k]))

            # 记录每个回合的元数据
            if "model" in traj["initial_state_dict"]:
                ep_data_grp.attrs["model_file"] = traj["initial_state_dict"]["model"]  # 该回合的模型XML文件
            ep_data_grp.attrs["num_samples"] = traj["actions"].shape[0]  # 该回合的样本数量
            total_samples += traj["actions"].shape[0]  # 更新总样本数

    rollout_stats = TensorUtils.list_of_flat_dict_to_dict_of_list(rollout_stats)
    avg_rollout_stats = { k : np.mean(rollout_stats[k]) for k in rollout_stats }
    avg_rollout_stats["Num_Success"] = np.sum(rollout_stats["Success_Rate"])
    avg_rollout_stats["Time_Episode"] = np.sum(rollout_stats["time"]) / 60.  # 总耗时(分钟)
    avg_rollout_stats["Num_Episode"] = len(rollout_stats["Success_Rate"])  # 总回合数
    print("Average Rollout Stats")
    stats_json = json.dumps(avg_rollout_stats, indent=4)
    print(stats_json)
    if args.json_path is not None:
        json_f = open(args.json_path, "w")
        json_f.write(stats_json)
        json_f.close()

    if write_video:
        video_writer.close()

    if write_dataset:
        # 写入全局元数据
        data_grp.attrs["total"] = total_samples
        data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4)  # 环境信息
        data_writer.close()
        print("Wrote dataset trajectories to {}".format(args.dataset_path))

根据命令行参数设置,加载策略和环境,执行多个回合,并根据需要进行渲染、视频录制和数据保存

3.1 参数

  • args:解析后的命令行参数

3.2 参数检查和设置

    # 一些参数检查
    write_video = (args.video_path is not None)  # 判断是否需要录制视频
    assert not (args.render and write_video)  # 确保不会同时进行屏幕渲染和视频录制

    rate_measure = None  # 用于测量动作计算速率的工具
    if args.hz is not None:
        import RobotTeleop
        from RobotTeleop.utils import Rate, RateMeasure, Timers
        rate_measure = RateMeasure(name="control_rate_measure", freq_threshold=args.hz)
  • 检查是否同时要求屏幕渲染和视频录制,如果是,抛出错误
  • 如果提供了动作计算频率参数 args.hz,导入相关的速率测量工具

***3.3 加载策略检查点

    # 加载策略检查点,并获取算法名称
    algo_name, ckpt_dict = FileUtils.algo_name_from_checkpoint(ckpt_path=args.agent)
  • 使用 FileUtils.algo_name_from_checkpoint 获取策略的算法名称和检查点字典 ckpt_dict
    if args.dp_eval_steps is not None:
        # 如果指定了dp_eval_steps参数,并且算法是diffusion_policy,则修改配置
        assert algo_name == "diffusion_policy"
        log_warning("setting @num_inference_steps to {}".format(args.dp_eval_steps))

        # 修改配置,然后重新写回ckpt_dict
        tmp_config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
        with tmp_config.values_unlocked():
            if tmp_config.algo.ddpm.enabled:
                tmp_config.algo.ddpm.num_inference_timesteps = args.dp_eval_steps
            elif tmp_config.algo.ddim.enabled:
                tmp_config.algo.ddim.num_inference_timesteps = args.dp_eval_steps
            else:
                raise Exception("should not reach here")
        ckpt_dict['config'] = tmp_config.dump()
  • 如果提供了 dp_eval_steps 参数,并且算法是扩散策略,则修改策略的配置,以设置推理步骤数
    # 确定设备(CPU或GPU)
    device = TorchUtils.get_torch_device(try_to_use_cuda=True)

    # 从检查点恢复策略
    policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_dict=ckpt_dict, device=device, verbose=True)
  • 确定计算设备(CPU 或 GPU)
  • 使用 FileUtils.policy_from_checkpoint 从检查点加载策略

3.4 设置回合参数

    # 读取回合设置
    rollout_num_episodes = args.n_rollouts  # 回合数量
    rollout_horizon = args.horizon  # 回合最大步数
    config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
    if rollout_horizon is None:
        # 如果未指定horizon参数,则从配置中获取
        rollout_horizon = config.experiment.rollout.horizon
  • 确定回合数量 rollout_num_episodes 和回合最大步数 rollout_num_episodes
  • 如果没有提供 args.horizon,则从策略配置中获取
    # HACK: assume absolute actions for now if using diffusion policy on real robot
    if (algo_name == "diffusion_policy") and EnvUtils.is_real_robot_gprs_env(env_meta=ckpt_dict["env_metadata"]):
        ckpt_dict["env_metadata"]["env_kwargs"]["absolute_actions"] = True
  • HACK: 如果在真实机器人上使用扩散策略,目前假设采用绝对动作

  • ["env_metadata"]["env_kwargs"]["absolute_actions"] 设置为 True

3.5 创建环境

    # 从保存的检查点创建环境
    env, _ = FileUtils.env_from_checkpoint(
        ckpt_dict=ckpt_dict, 
        env_name=args.env, 
        render=args.render, 
        render_offscreen=(args.video_path is not None), 
        verbose=True,
    )
  • 使用 FileUtils.env_from_checkpoint 根据检查点字典创建环境
  • 根据参数设置渲染选项
    # 如果未指定camera_names,则使用默认摄像机配置
    if args.camera_names is None:
        # We fill in the automatic values
        env_type = EnvUtils.get_env_type(env=env)
        args.camera_names = DEFAULT_CAMERAS[env_type]
    if args.render:
        # 屏幕渲染只能支持一个摄像机
        assert len(args.camera_names) == 1
  • 如果没有提供 ,则使用默认的摄像机名称
    is_real_robot = EnvUtils.is_real_robot_env(env=env) or EnvUtils.is_real_robot_gprs_env(env=env)
    if is_real_robot:
        # 如果是真实机器人环境,记录一些警告信息
        need_pause = False
        if "env_name" not in ckpt_dict["env_metadata"]["env_kwargs"]:
            log_warning("env_name not in checkpoint...proceed with caution...")
            need_pause = True
        if ckpt_dict["env_metadata"]["env_name"] != "EnvRealPandaGPRS":
            # 即使策略是在不同的环境类中收集的,我们也将默认加载EnvRealPandaGPRS类
            log_warning("env name in metadata appears to be class ({}) different from EnvRealPandaGPRS".format(ckpt_dict["env_metadata"]["env_name"]))
            need_pause = True
        if need_pause:
            ans = input("continue? (y/n)")
            if ans != "y":
                exit()
  • 检查是否是真实机器人环境,并根据需要记录警告信息

3.6 设置随机种子

    # 如果提供了种子参数,则设置随机种子
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
  • 如果提供了 args.seed,则设置 NumPy 和 PyTorch 的随机种子,以确保结果的可重复性

3.7 初始化视频和数据写入器

    # 可能需要创建视频写入器
    video_writer = None
    if write_video:
        video_writer = imageio.get_writer(args.video_path, fps=20)

    # 如果需要写入数据集,打开HDF5文件
    write_dataset = (args.dataset_path is not None)
    if write_dataset:
        data_writer = h5py.File(args.dataset_path, "w")
        data_grp = data_writer.create_group("data")
        total_samples = 0
  • 如果提供了 args.video_path,则创建视频写入器 video_writer
  • 如果提供了 args.dataset_path,则打开 HDF5 文件以保存回合数据

3.8 执行回合

    rollout_stats = []  # 用于存储每个回合的统计信息
    for i in tqdm(range(rollout_num_episodes)):  # 使用进度条显示回合进度
        try:
            stats, traj = rollout(
                policy=policy, 
                env=env, 
                horizon=rollout_horizon, 
                render=args.render, 
                video_writer=video_writer, 
                video_skip=args.video_skip, 
                return_obs=(write_dataset and args.dataset_obs),
                camera_names=args.camera_names,
                real=is_real_robot,
                rate_measure=rate_measure,
            )
  • 初始化一个列表 rollout_stats,用于存储每个回合的统计信息
  • 使用 tqdm 创建一个进度条,遍历回合数
  • 在每个回合中,调用 rollout 函数执行回合,并捕获统计信息和轨迹数据
        except KeyboardInterrupt:
            if is_real_robot:
                print("ctrl-C catched, stop execution")
                print("env rate measure")
                print(env.rate_measure)
                ans = input("success? (y / n)")
                rollout_stats.append((1 if ans == "y" else 0))
                print("*" * 50)
                print("have {} success out of {} attempts".format(np.sum(rollout_stats), len(rollout_stats)))
                print("*" * 50)
                continue
            else:
                sys.exit(0)
        
        if is_real_robot:
            print("TERMINATE WITHOUT KEYBOARD INTERRUPT...")
            ans = input("success? (y / n)")
            rollout_stats.append((1 if ans == "y" else 0))
            continue
        rollout_stats.append(stats)
  • 如果是在真实机器人上运行,捕获 KeyboardInterrupt 异常,并提示用户输入回合是否成功
        if write_dataset:
            # 将数据集写入HDF5文件
            ep_data_grp = data_grp.create_group("demo_{}".format(i))
            ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
            ep_data_grp.create_dataset("states", data=np.array(traj["states"]))
            ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
            ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
            if args.dataset_obs:
                for k in traj["obs"]:
                    ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]))
                    ep_data_grp.create_dataset("next_obs/{}".format(k), data=np.array(traj["next_obs"][k]))

            # 记录每个回合的元数据
            if "model" in traj["initial_state_dict"]:
                ep_data_grp.attrs["model_file"] = traj["initial_state_dict"]["model"]  # 该回合的模型XML文件
            ep_data_grp.attrs["num_samples"] = traj["actions"].shape[0]  # 该回合的样本数量
            total_samples += traj["actions"].shape[0]  # 更新总样本数
  • 如果需要保存数据,将轨迹数据写入 HDF5 文件

3.9 回合结束处理

    rollout_stats = TensorUtils.list_of_flat_dict_to_dict_of_list(rollout_stats)
    avg_rollout_stats = { k : np.mean(rollout_stats[k]) for k in rollout_stats }
    avg_rollout_stats["Num_Success"] = np.sum(rollout_stats["Success_Rate"])
    avg_rollout_stats["Time_Episode"] = np.sum(rollout_stats["time"]) / 60.  # 总耗时(分钟)
    avg_rollout_stats["Num_Episode"] = len(rollout_stats["Success_Rate"])  # 总回合数
    print("Average Rollout Stats")
    stats_json = json.dumps(avg_rollout_stats, indent=4)
    print(stats_json)
  • 计算所有回合的平均统计信息,并打印输出
    if args.json_path is not None:
        json_f = open(args.json_path, "w")
        json_f.write(stats_json)
        json_f.close()

    if write_video:
        video_writer.close()

    if write_dataset:
        # 写入全局元数据
        data_grp.attrs["total"] = total_samples
        data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4)  # 环境信息
        data_writer.close()
        print("Wrote dataset trajectories to {}".format(args.dataset_path))
  • 如果提供了 args.json_path,将统计信息保存为 JSON 文件
  • 关闭视频和数据写入器

4. 解析命令行参数并执行脚本

# 主程序入口
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # 添加命令行参数
    # Path to trained model
    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        help="path to saved checkpoint pth file",
    )

    # number of rollouts
    parser.add_argument(
        "--n_rollouts",
        type=int,
        default=27,
        help="number of rollouts",
    )

    # maximum horizon of rollout, to override the one stored in the model checkpoint
    parser.add_argument(
        "--horizon",
        type=int,
        default=None,
        help="(optional) override maximum horizon of rollout from the one in the checkpoint",
    )

    # Env Name (to override the one stored in model checkpoint)
    parser.add_argument(
        "--env",
        type=str,
        default=None,
        help="(optional) override name of env from the one in the checkpoint, and use\
            it for rollouts",
    )

    # Whether to render rollouts to screen
    parser.add_argument(
        "--render",
        action='store_true',
        help="on-screen rendering",
    )

    # Dump a video of the rollouts to the specified path
    parser.add_argument(
        "--video_path",
        type=str,
        default=None,
        help="(optional) render rollouts to this video file path",
    )

    # How often to write video frames during the rollout
    parser.add_argument(
        "--video_skip",
        type=int,
        default=5,
        help="render frames to video every n steps",
    )

    # camera names to render
    parser.add_argument(
        "--camera_names",
        type=str,
        nargs='+',
        default=None,
        help="(optional) camera name(s) to use for rendering on-screen or to video",
    )

    # If provided, an hdf5 file will be written with the rollout data
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        help="(optional) if provided, an hdf5 file will be written at this path with the rollout data",
    )

    # If True and @dataset_path is supplied, will write possibly high-dimensional observations to dataset.
    parser.add_argument(
        "--dataset_obs",
        action='store_true',
        help="include possibly high-dimensional observations in output dataset hdf5 file (by default,\
            observations are excluded and only simulator states are saved)",
    )

    # for seeding before starting rollouts
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="(optional) set seed for rollouts",
    )

    # Dump a json of the rollout results stats to the specified path
    parser.add_argument(
        "--json_path",
        type=str,
        default=None,
        help="(optional) dump a json of the rollout results stats to the specified path",
    )

    # Dump a file with the error traceback at this path. Only created if run fails with an error.
    parser.add_argument(
        "--error_path",
        type=str,
        default=None,
        help="(optional) dump a file with the error traceback at this path. Only created if run fails with an error.",
    )

    # TODO: clean up this arg
    # If provided, do not run actions in env, and instead just measure the rate of action computation
    parser.add_argument(
        "--hz",
        type=int,
        default=None,
        help="If provided, do not run actions in env, and instead just measure the rate of action computation and raise warnings if it dips below this threshold",
    )

    # TODO: clean up this arg
    # If provided, set num_inference_timesteps explicitly for diffusion policy evaluation
    parser.add_argument(
        "--dp_eval_steps",
        type=int,
        default=None,
        help="If provided, set num_inference_timesteps explicitly for diffusion policy evaluation",
    )

    args = parser.parse_args()
    res_str = None
    try:
        run_trained_agent(args)  # 调用主函数,开始策略评估
    except Exception as e:
        res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
        if args.error_path is not None:
            # 将错误信息写入指定的错误文件
            f = open(args.error_path, "w")
            f.write(res_str)
            f.close()
        raise e  # 重新抛出异常

使用 argparse 模块解析命令行参数,支持多种参数配置

如果在执行过程中发生异常,捕获并将错误信息写入指定的错误文件,然后重新抛出异常

4.1 命令行参数

    # Path to trained model
    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        help="path to saved checkpoint pth file",
    )
  • --agent:必须提供,指定训练好的策略检查点文件路径
    # number of rollouts
    parser.add_argument(
        "--n_rollouts",
        type=int,
        default=27,
        help="number of rollouts",
    )
  • --n_rollouts:执行的回合数量,默认值为 27
    # maximum horizon of rollout, to override the one stored in the model checkpoint
    parser.add_argument(
        "--horizon",
        type=int,
        default=None,
        help="(optional) override maximum horizon of rollout from the one in the checkpoint",
    )
  • --horizon:每个回合的最大步数,可选
    # Env Name (to override the one stored in model checkpoint)
    parser.add_argument(
        "--env",
        type=str,
        default=None,
        help="(optional) override name of env from the one in the checkpoint, and use\
            it for rollouts",
    )
  • --env:指定环境名称,覆盖检查点中的环境配置
    # Whether to render rollouts to screen
    parser.add_argument(
        "--render",
        action='store_true',
        help="on-screen rendering",
    )
  • --render:启用屏幕渲染
    # Dump a video of the rollouts to the specified path
    parser.add_argument(
        "--video_path",
        type=str,
        default=None,
        help="(optional) render rollouts to this video file path",
    )
  • --video_path:指定视频保存路径,启用视频录制
    # How often to write video frames during the rollout
    parser.add_argument(
        "--video_skip",
        type=int,
        default=5,
        help="render frames to video every n steps",
    )
  • --video_skip:每隔多少步保存一帧到视频
    # camera names to render
    parser.add_argument(
        "--camera_names",
        type=str,
        nargs='+',
        default=None,
        help="(optional) camera name(s) to use for rendering on-screen or to video",
    )
  • --camera_names:指定用于渲染的摄像机名称列表
    # If provided, an hdf5 file will be written with the rollout data
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        help="(optional) if provided, an hdf5 file will be written at this path with the rollout data",
    )
  • --dataset_path:指定数据集保存路径,启用数据保存
    # If True and @dataset_path is supplied, will write possibly high-dimensional observations to dataset.
    parser.add_argument(
        "--dataset_obs",
        action='store_true',
        help="include possibly high-dimensional observations in output dataset hdf5 file (by default,\
            observations are excluded and only simulator states are saved)",
    )
  • --dataset_obs:在保存数据时包含观测信息
    # for seeding before starting rollouts
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="(optional) set seed for rollouts",
    )
  • --seed:设置随机种子
    # Dump a json of the rollout results stats to the specified path
    parser.add_argument(
        "--json_path",
        type=str,
        default=None,
        help="(optional) dump a json of the rollout results stats to the specified path",
    )
  • --json_path:指定统计信息保存为 JSON 文件的路径
    # Dump a file with the error traceback at this path. Only created if run fails with an error.
    parser.add_argument(
        "--error_path",
        type=str,
        default=None,
        help="(optional) dump a file with the error traceback at this path. Only created if run fails with an error.",
    )
  • --error_path:指定错误信息保存路径
    # TODO: clean up this arg
    # If provided, do not run actions in env, and instead just measure the rate of action computation
    parser.add_argument(
        "--hz",
        type=int,
        default=None,
        help="If provided, do not run actions in env, and instead just measure the rate of action computation and raise warnings if it dips below this threshold",
    )
  • --hz:指定动作计算频率,用于速率测量
    # TODO: clean up this arg
    # If provided, set num_inference_timesteps explicitly for diffusion policy evaluation
    parser.add_argument(
        "--dp_eval_steps",
        type=int,
        default=None,
        help="If provided, set num_inference_timesteps explicitly for diffusion policy evaluation",
    )
  • --dp_eval_steps:设置扩散策略的推理步骤数

4.2 代码运行

    args = parser.parse_args()
    res_str = None
    try:
        run_trained_agent(args)  # 调用主函数,开始策略评估
    except Exception as e:
        res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
        if args.error_path is not None:
            # 将错误信息写入指定的错误文件
            f = open(args.error_path, "w")
            f.write(res_str)
            f.close()
        raise e  # 重新抛出异常

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值