In the DexCap source code, the Acknowledgements section points to the components used for deployment and robot communication:
- Our policy training is implemented based on robomimic, Diffusion Policy.
- The robot arm controller is based on Deoxys.
- The robot LEAP hand controller is based on LEAP_Hand_API.
The run_trained_agent.py file is robomimic's dedicated eval script for rolling out a trained policy.
DexCap's deployment is therefore closely tied to run_trained_agent.py.
Prerequisite background: see the robomimic tutorial (Part 2) on policy rollout and evaluation.
I. Deployment Logic
The core loop: DexCap takes camera observations as input, generates an action trajectory with the diffusion policy, and sends the actions to the robot arms and end-effectors for execution.
1. Input
The input is the .pth file, i.e. the trained diffusion-policy checkpoint.
2. Model
Training produces this .pth file, which stores the parameters (weights and biases) of the diffusion-policy model.
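A quick way to peek inside such a checkpoint is to load it with torch.load. The keys config, env_metadata, and env_name are the ones the evaluation script itself reads; the rest of this snippet is only a sketch, and the file name is hypothetical:

import torch

ckpt_dict = torch.load("model.pth", map_location="cpu")
print(list(ckpt_dict.keys()))                 # e.g. config, env_metadata, model weights, ...
print(ckpt_dict["env_metadata"]["env_name"])  # which environment the policy was trained in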
3. Output
Once deployed, the trained policy outputs a 46-dimensional action, composed as follows (in order):
- 3-dim right-arm end-effector translation: moves the right end-effector in space
- 3-dim left-arm end-effector translation: moves the left end-effector in space
- 4-dim right-arm end-effector orientation: rotation of the right end-effector, as a quaternion
- 4-dim left-arm end-effector orientation: rotation of the left end-effector, as a quaternion
- 16-dim right LEAP hand joint positions: joint targets of the right LEAP hand
- 16-dim left LEAP hand joint positions: joint targets of the left LEAP hand
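Assuming the ordering above, a single 46-dim action vector could be unpacked like this (a sketch; the variable names are illustrative and not taken from the DexCap code):

import numpy as np

act = np.zeros(46)                # one action step produced by the policy
right_ee_pos  = act[0:3]          # right end-effector translation
left_ee_pos   = act[3:6]          # left end-effector translation
right_ee_quat = act[6:10]         # right end-effector quaternion
left_ee_quat  = act[10:14]        # left end-effector quaternion
right_hand_q  = act[14:30]        # 16 right LEAP hand joint positions
left_hand_q   = act[30:46]        # 16 left LEAP hand joint positions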
At the moment the learned policy can only be validated on a real robot.
Robot arm: the 7-dim output per arm (3-dim translation + 4-dim quaternion) drives end-effector pose control.
Dexterous hand: assembled following the open-source LEAP Hand documentation.
Wrist and finger retargeting are handled separately: end-effector position control first moves the robot to the palm pose captured from the human data (wrist-camera frame), and fingertip IK is then solved within the wrist frame.
II. Code Walkthrough
The script's main job is to evaluate a trained policy in an environment using the robomimic library.
It supports on-screen rendering, video recording, and saving rollout data for replay.
1. Imports
# 导入所需的库
import argparse # 用于解析命令行参数
import os
import json
import h5py # 用于处理HDF5文件格式
import imageio # 用于读写图像和视频
import sys
import time
import traceback # 用于获取异常的详细信息
import numpy as np
from copy import deepcopy # 用于深拷贝对象
from tqdm import tqdm # 用于显示进度条
import torch # PyTorch库
# 导入robomimic库及其子模块
import robomimic
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.utils.log_utils import log_warning
from robomimic.envs.env_base import EnvBase
from robomimic.envs.wrappers import EnvWrapper
from robomimic.algo import RolloutPolicy
from robomimic.scripts.playback_dataset import DEFAULT_CAMERAS # 默认摄像机配置
The standard library, PyTorch, and robomimic (with its utility submodules) are imported.
2. The helper function rollout
# 定义一个函数,用于执行一次策略的回合(rollout)
def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, return_obs=False, camera_names=None, real=False, rate_measure=None):
"""
Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video,
and returns the rollout trajectory.
Args:
policy (instance of RolloutPolicy): policy loaded from a checkpoint
env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
horizon (int): maximum horizon for the rollout
render (bool): whether to render rollout on-screen
video_writer (imageio writer): if provided, use to write rollout to video
video_skip (int): how often to write video frames
return_obs (bool): if True, return possibly high-dimensional observations along the trajectory.
They are excluded by default because the low-dimensional simulation states should be a minimal
representation of the environment.
camera_names (list): determines which camera(s) are used for rendering. Pass more than
one to output a video with multiple camera views concatenated horizontally.
real (bool): if real robot rollout
rate_measure: if provided, measure rate of action computation and do not play actions in environment
Returns:
stats (dict): some statistics for the rollout - such as return, horizon, and task success
traj (dict): dictionary that corresponds to the rollout trajectory
"""
rollout_timestamp = time.time() # 记录当前时间戳,便于计算回合耗时
assert isinstance(env, EnvBase) or isinstance(env, EnvWrapper) # 确保环境是EnvBase或EnvWrapper的实例
assert isinstance(policy, RolloutPolicy) # 确保策略是RolloutPolicy的实例
assert not (render and (video_writer is not None)) # 确保不会同时进行屏幕渲染和视频录制
policy.start_episode() # 通知策略开始一个新的回合
obs = env.reset() # 重置环境,获取初始观测
state_dict = dict() # 用于存储环境状态(主要用于模拟环境)
if real:
input("ready for next eval? hit enter to continue") # 如果是真实机器人环境,等待用户确认开始
else:
state_dict = env.get_state() # 获取当前环境状态
# 对于robosuite任务,为了确保动作的可重复性,需要重置到获取的状态
obs = env.reset_to(state_dict)
results = {} # 用于存储回合结果的字典
video_count = 0 # 视频帧计数器
total_reward = 0. # 累积奖励
got_exception = False # 用于标记是否在回合中遇到异常
success = env.is_success()["task"] # 检查任务是否成功
traj = dict(actions=[], rewards=[], dones=[], states=[], initial_state_dict=state_dict) # 用于存储回合轨迹的数据
if return_obs:
# 如果需要返回观测数据,初始化相应的列表
traj.update(dict(obs=[], next_obs=[]))
try:
for step_i in range(horizon): # 开始回合的主循环,最多运行到指定的horizon步
# HACK: some keys on real robot do not have a shape (and then they get frame stacked)
for k in obs:
if len(obs[k].shape) == 1:
obs[k] = obs[k][..., None] # 确保观测数据的形状正确
# get action from policy
t1 = time.time() # 记录动作计算开始时间
act = policy(ob=obs) # 使用策略计算动作
t2 = time.time() # 记录动作计算结束时间
if real and (not env.base_env.controller_type == "JOINT_IMPEDANCE") and (policy.policy.global_config.algo_name != "diffusion_policy"):
# 如果是真实机器人环境,且控制器不是JOINT_IMPEDANCE,且策略不是diffusion_policy
# 为了安全,动作值需要裁剪到[-1, 1]范围内
act = np.clip(act, -1., 1.)
if rate_measure is not None:
# 如果提供了速率测量工具,则测量动作计算速率,而不在环境中执行动作
rate_measure.measure()
print("time: {}s".format(t2 - t1)) # 打印动作计算耗时
# dummy reward and done
r = 0.
done = False
next_obs = obs
else:
# play action
next_obs, r, done, _ = env.step(act) # 在环境中执行动作,获取下一个观测、奖励、是否完成等信息
# compute reward
total_reward += r # 累积奖励
success = env.is_success()["task"] # 检查任务是否成功
# visualization
if render:
env.render(mode="human", camera_name=camera_names[0]) # 在屏幕上渲染环境
if video_writer is not None:
if video_count % video_skip == 0:
video_img = []
for cam_name in camera_names:
# 从指定的摄像机获取渲染图像
video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
video_img = np.concatenate(video_img, axis=1) # 将多个摄像机的图像水平拼接
video_writer.append_data(video_img) # 将图像写入视频
video_count += 1 # 增加视频帧计数
# collect transition
traj["actions"].append(act) # 记录动作
traj["rewards"].append(r) # 记录奖励
traj["dones"].append(done) # 记录是否完成
if not real:
traj["states"].append(state_dict["states"]) # 记录环境状态(仅用于模拟环境)
if return_obs:
# 需要将观测数据“反处理”,以便于保存到数据集中
traj["obs"].append(ObsUtils.unprocess_obs_dict(obs))
traj["next_obs"].append(ObsUtils.unprocess_obs_dict(next_obs))
# break if done or if success
if done or success:
break # 如果任务完成或达到终止条件,退出循环
# update for next iter
obs = deepcopy(next_obs) # 更新当前观测
if not real:
state_dict = env.get_state() # 更新环境状态(仅用于模拟环境)
except env.rollout_exceptions as e:
# 捕获在回合中可能发生的异常,记录警告信息
print("WARNING: got rollout exception {}".format(e))
got_exception = True
stats = dict(
Return=total_reward, # 总奖励
Horizon=(step_i + 1), # 实际运行的步数
Success_Rate=float(success), # 任务成功率
Exception_Rate=float(got_exception), # 异常发生率
time=(time.time() - rollout_timestamp), # 回合耗时
)
if return_obs:
# 将观测数据从列表转换为字典,以便于保存
traj["obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["obs"])
traj["next_obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["next_obs"])
# list to numpy array
for k in traj:
if k == "initial_state_dict":
continue # 初始状态字典不需要转换
if isinstance(traj[k], dict):
for kp in traj[k]:
traj[k][kp] = np.array(traj[k][kp]) # 将数据转换为NumPy数组
else:
traj[k] = np.array(traj[k])
return stats, traj # 返回回合统计信息和轨迹数据
This function runs one rollout in the environment: the policy acts for up to horizon steps while trajectory data and statistics are collected.
2.1 Parameters
- policy (instance of RolloutPolicy): trained policy loaded from a checkpoint; outputs an action given an observation
- env (instance of EnvBase): environment instance (simulation or real robot), loaded from the checkpoint or from demonstration metadata
- horizon (int): maximum number of steps in the rollout
- render (bool): whether to render the rollout on-screen
- video_writer (imageio writer): if provided, rendered frames are written to this video writer
- video_skip (int): write one frame to the video every this many steps
- return_obs (bool): if True, also return the (possibly high-dimensional) observations along the trajectory; they are excluded by default because the low-dimensional simulator states are a minimal representation of the environment
- camera_names (list): camera name(s) used for rendering
- real (bool): whether the rollout runs on a real robot
- rate_measure: if provided, only the rate of action computation is measured and actions are not executed in the environment
2.2 Return values
- stats (dict): rollout statistics such as return, horizon, and task success
- traj (dict): the rollout trajectory
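Before the line-by-line walkthrough, here is a minimal usage sketch of the function (the checkpoint path and camera name are placeholders; FileUtils and TorchUtils come from robomimic as in the imports above):

import robomimic.utils.file_utils as FileUtils
import robomimic.utils.torch_utils as TorchUtils

device = TorchUtils.get_torch_device(try_to_use_cuda=True)
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path="model.pth", device=device, verbose=False)
env, _ = FileUtils.env_from_checkpoint(ckpt_dict=ckpt_dict, render=False, render_offscreen=False, verbose=False)

stats, traj = rollout(policy=policy, env=env, horizon=400, camera_names=["agentview"])
print(stats)                   # Return, Horizon, Success_Rate, Exception_Rate, time
print(traj["actions"].shape)   # (num_steps, action_dim)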
2.3 Initialization
rollout_timestamp = time.time() # 记录当前时间戳,便于计算回合耗时
assert isinstance(env, EnvBase) or isinstance(env, EnvWrapper) # 确保环境是EnvBase或EnvWrapper的实例
assert isinstance(policy, RolloutPolicy) # 确保策略是RolloutPolicy的实例
assert not (render and (video_writer is not None)) # 确保不会同时进行屏幕渲染和视频录制
- Record the rollout start time
- Check that env is a valid environment instance (EnvBase or EnvWrapper) and that policy is a RolloutPolicy
- Make sure on-screen rendering and video recording are not requested at the same time
policy.start_episode() # 通知策略开始一个新的回合
obs = env.reset() # 重置环境,获取初始观测
state_dict = dict() # 用于存储环境状态(主要用于模拟环境)
if real:
input("ready for next eval? hit enter to continue") # 如果是真实机器人环境,等待用户确认开始
else:
state_dict = env.get_state() # 获取当前环境状态
# 对于robosuite任务,为了确保动作的可重复性,需要重置到获取的状态
obs = env.reset_to(state_dict)
results = {} # 用于存储回合结果的字典
video_count = 0 # 视频帧计数器
total_reward = 0. # 累积奖励
got_exception = False # 用于标记是否在回合中遇到异常
success = env.is_success()["task"] # 检查任务是否成功
traj = dict(actions=[], rewards=[], dones=[], states=[], initial_state_dict=state_dict) # 用于存储回合轨迹的数据
if return_obs:
# 如果需要返回观测数据,初始化相应的列表
traj.update(dict(obs=[], next_obs=[]))
- Call policy.start_episode() to notify the policy that a new episode begins
- Reset the environment and obtain the initial observation obs; in simulation, additionally reset to the captured state for reproducibility
- If running on a real robot, wait for the user to confirm before starting
- Initialize the bookkeeping variables (reward accumulator, success flag, exception flag) and the dictionary traj that stores the rollout data
2.4 Main loop
try:
for step_i in range(horizon): # 开始回合的主循环,最多运行到指定的horizon步
# HACK: some keys on real robot do not have a shape (and then they get frame stacked)
for k in obs:
if len(obs[k].shape) == 1:
obs[k] = obs[k][..., None] # 确保观测数据的形状正确
- Observation shape fix: make sure every observation entry has at least two dimensions so that frame stacking and later processing do not break
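What this fix does, in isolation (the observation key is hypothetical):

import numpy as np

obs = {"ee_pos": np.zeros(3)}        # shape (3,), which would break frame stacking
for k in obs:
    if len(obs[k].shape) == 1:
        obs[k] = obs[k][..., None]
print(obs["ee_pos"].shape)           # (3, 1)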
# get action from policy
t1 = time.time() # 记录动作计算开始时间
act = policy(ob=obs) # 使用策略计算动作
t2 = time.time() # 记录动作计算结束时间
if real and (not env.base_env.controller_type == "JOINT_IMPEDANCE") and (policy.policy.global_config.algo_name != "diffusion_policy"):
# 如果是真实机器人环境,且控制器不是JOINT_IMPEDANCE,且策略不是diffusion_policy
# 为了安全,动作值需要裁剪到[-1, 1]范围内
act = np.clip(act, -1., 1.)
- Compute the action: the policy produces the next action from the current observation
- Measure latency: timestamps before and after the call record how long action computation takes
- Clip the action: on a real robot, when the controller is not JOINT_IMPEDANCE and the policy is not diffusion_policy, the action is clipped to [-1, 1] for safety
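The safety clip in isolation, on toy values (illustrative only):

import numpy as np

act = np.array([1.7, -0.2, -3.0])
act = np.clip(act, -1., 1.)   # -> array([ 1. , -0.2, -1. ])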
if rate_measure is not None:
# 如果提供了速率测量工具,则测量动作计算速率,而不在环境中执行动作
rate_measure.measure()
print("time: {}s".format(t2 - t1)) # 打印动作计算耗时
# dummy reward and done
r = 0.
done = False
next_obs = obs
else:
# play action
next_obs, r, done, _ = env.step(act) # 在环境中执行动作,获取下一个观测、奖励、是否完成等信息
# compute reward
total_reward += r # 累积奖励
success = env.is_success()["task"] # 检查任务是否成功
- Rate measurement: if rate_measure is provided, only the action-computation rate is measured and the environment step is skipped (dummy reward/done)
- Environment interaction: otherwise the action is applied to the environment, returning the next observation, reward, done flag, and extra info
- Reward and success bookkeeping: accumulate the total reward and check whether the task has succeeded
# visualization
if render:
env.render(mode="human", camera_name=camera_names[0]) # 在屏幕上渲染环境
if video_writer is not None:
if video_count % video_skip == 0:
video_img = []
for cam_name in camera_names:
# 从指定的摄像机获取渲染图像
video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
video_img = np.concatenate(video_img, axis=1) # 将多个摄像机的图像水平拼接
video_writer.append_data(video_img) # 将图像写入视频
video_count += 1 # 增加视频帧计数
- Rendering and video recording: depending on the arguments, render on-screen or append frames (every video_skip steps, concatenated across cameras) to the video writer
# collect transition
traj["actions"].append(act) # 记录动作
traj["rewards"].append(r) # 记录奖励
traj["dones"].append(done) # 记录是否完成
if not real:
traj["states"].append(state_dict["states"]) # 记录环境状态(仅用于模拟环境)
if return_obs:
# 需要将观测数据“反处理”,以便于保存到数据集中
traj["obs"].append(ObsUtils.unprocess_obs_dict(obs))
traj["next_obs"].append(ObsUtils.unprocess_obs_dict(next_obs))
- Collect the transition: store the action, reward, done flag, simulator state, and optionally the un-processed observations in traj
# break if done or if success
if done or success:
break # 如果任务完成或达到终止条件,退出循环
- Early termination: if the task succeeds or the episode is done, exit the loop
# update for next iter
obs = deepcopy(next_obs) # 更新当前观测
if not real:
state_dict = env.get_state() # 更新环境状态(仅用于模拟环境)
- Prepare the next iteration: update the current observation and, in simulation, the environment state
2.5 Exception handling
except env.rollout_exceptions as e:
# 捕获在回合中可能发生的异常,记录警告信息
print("WARNING: got rollout exception {}".format(e))
got_exception = True
- Exceptions raised during the rollout (env.rollout_exceptions) are caught and logged as a warning instead of aborting the run
2.6 End-of-rollout processing
stats = dict(
Return=total_reward, # 总奖励
Horizon=(step_i + 1), # 实际运行的步数
Success_Rate=float(success), # 任务成功率
Exception_Rate=float(got_exception), # 异常发生率
time=(time.time() - rollout_timestamp), # 回合耗时
)
if return_obs:
# 将观测数据从列表转换为字典,以便于保存
traj["obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["obs"])
traj["next_obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["next_obs"])
- Compute the rollout statistics (total return, number of steps, success, whether an exception occurred, elapsed time) and, if return_obs is set, merge the per-step observation dicts into a single dict of lists
# list to numpy array
for k in traj:
if k == "initial_state_dict":
continue # 初始状态字典不需要转换
if isinstance(traj[k], dict):
for kp in traj[k]:
traj[k][kp] = np.array(traj[k][kp]) # 将数据转换为NumPy数组
else:
traj[k] = np.array(traj[k])
- Convert the collected lists to NumPy arrays so they can be processed or saved conveniently
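A sketch of these two conversions on toy data (the observation key is hypothetical):

import numpy as np
import robomimic.utils.tensor_utils as TensorUtils

obs_list = [{"ee_pos": np.zeros(3)}, {"ee_pos": np.ones(3)}]   # one dict per step
obs_dict = TensorUtils.list_of_flat_dict_to_dict_of_list(obs_list)
obs_dict["ee_pos"] = np.array(obs_dict["ee_pos"])
print(obs_dict["ee_pos"].shape)   # (2, 3): steps stacked along the first axis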
2.7 Return values
return stats, traj # 返回回合统计信息和轨迹数据
- stats: dictionary of rollout statistics
- traj: dictionary of rollout trajectory data
3. The main function run_trained_agent
# 主函数,用于根据命令行参数执行策略评估
def run_trained_agent(args):
# 一些参数检查
write_video = (args.video_path is not None) # 判断是否需要录制视频
assert not (args.render and write_video) # 确保不会同时进行屏幕渲染和视频录制
rate_measure = None # 用于测量动作计算速率的工具
if args.hz is not None:
import RobotTeleop
from RobotTeleop.utils import Rate, RateMeasure, Timers
rate_measure = RateMeasure(name="control_rate_measure", freq_threshold=args.hz)
# 加载策略检查点,并获取算法名称
algo_name, ckpt_dict = FileUtils.algo_name_from_checkpoint(ckpt_path=args.agent)
if args.dp_eval_steps is not None:
# 如果指定了dp_eval_steps参数,并且算法是diffusion_policy,则修改配置
assert algo_name == "diffusion_policy"
log_warning("setting @num_inference_steps to {}".format(args.dp_eval_steps))
# 修改配置,然后重新写回ckpt_dict
tmp_config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
with tmp_config.values_unlocked():
if tmp_config.algo.ddpm.enabled:
tmp_config.algo.ddpm.num_inference_timesteps = args.dp_eval_steps
elif tmp_config.algo.ddim.enabled:
tmp_config.algo.ddim.num_inference_timesteps = args.dp_eval_steps
else:
raise Exception("should not reach here")
ckpt_dict['config'] = tmp_config.dump()
# 确定设备(CPU或GPU)
device = TorchUtils.get_torch_device(try_to_use_cuda=True)
# 从检查点恢复策略
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_dict=ckpt_dict, device=device, verbose=True)
# 读取回合设置
rollout_num_episodes = args.n_rollouts # 回合数量
rollout_horizon = args.horizon # 回合最大步数
config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
if rollout_horizon is None:
# 如果未指定horizon参数,则从配置中获取
rollout_horizon = config.experiment.rollout.horizon
# HACK: assume absolute actions for now if using diffusion policy on real robot
if (algo_name == "diffusion_policy") and EnvUtils.is_real_robot_gprs_env(env_meta=ckpt_dict["env_metadata"]):
ckpt_dict["env_metadata"]["env_kwargs"]["absolute_actions"] = True
# 从保存的检查点创建环境
env, _ = FileUtils.env_from_checkpoint(
ckpt_dict=ckpt_dict,
env_name=args.env,
render=args.render,
render_offscreen=(args.video_path is not None),
verbose=True,
)
# 如果未指定camera_names,则使用默认摄像机配置
if args.camera_names is None:
# We fill in the automatic values
env_type = EnvUtils.get_env_type(env=env)
args.camera_names = DEFAULT_CAMERAS[env_type]
if args.render:
# 屏幕渲染只能支持一个摄像机
assert len(args.camera_names) == 1
is_real_robot = EnvUtils.is_real_robot_env(env=env) or EnvUtils.is_real_robot_gprs_env(env=env)
if is_real_robot:
# 如果是真实机器人环境,记录一些警告信息
need_pause = False
if "env_name" not in ckpt_dict["env_metadata"]["env_kwargs"]:
log_warning("env_name not in checkpoint...proceed with caution...")
need_pause = True
if ckpt_dict["env_metadata"]["env_name"] != "EnvRealPandaGPRS":
# 即使策略是在不同的环境类中收集的,我们也将默认加载EnvRealPandaGPRS类
log_warning("env name in metadata appears to be class ({}) different from EnvRealPandaGPRS".format(ckpt_dict["env_metadata"]["env_name"]))
need_pause = True
if need_pause:
ans = input("continue? (y/n)")
if ans != "y":
exit()
# 如果提供了种子参数,则设置随机种子
if args.seed is not None:
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# 可能需要创建视频写入器
video_writer = None
if write_video:
video_writer = imageio.get_writer(args.video_path, fps=20)
# 如果需要写入数据集,打开HDF5文件
write_dataset = (args.dataset_path is not None)
if write_dataset:
data_writer = h5py.File(args.dataset_path, "w")
data_grp = data_writer.create_group("data")
total_samples = 0
rollout_stats = [] # 用于存储每个回合的统计信息
for i in tqdm(range(rollout_num_episodes)): # 使用进度条显示回合进度
try:
stats, traj = rollout(
policy=policy,
env=env,
horizon=rollout_horizon,
render=args.render,
video_writer=video_writer,
video_skip=args.video_skip,
return_obs=(write_dataset and args.dataset_obs),
camera_names=args.camera_names,
real=is_real_robot,
rate_measure=rate_measure,
)
except KeyboardInterrupt:
if is_real_robot:
print("ctrl-C catched, stop execution")
print("env rate measure")
print(env.rate_measure)
ans = input("success? (y / n)")
rollout_stats.append((1 if ans == "y" else 0))
print("*" * 50)
print("have {} success out of {} attempts".format(np.sum(rollout_stats), len(rollout_stats)))
print("*" * 50)
continue
else:
sys.exit(0)
if is_real_robot:
print("TERMINATE WITHOUT KEYBOARD INTERRUPT...")
ans = input("success? (y / n)")
rollout_stats.append((1 if ans == "y" else 0))
continue
rollout_stats.append(stats)
if write_dataset:
# 将数据集写入HDF5文件
ep_data_grp = data_grp.create_group("demo_{}".format(i))
ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
ep_data_grp.create_dataset("states", data=np.array(traj["states"]))
ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
if args.dataset_obs:
for k in traj["obs"]:
ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]))
ep_data_grp.create_dataset("next_obs/{}".format(k), data=np.array(traj["next_obs"][k]))
# 记录每个回合的元数据
if "model" in traj["initial_state_dict"]:
ep_data_grp.attrs["model_file"] = traj["initial_state_dict"]["model"] # 该回合的模型XML文件
ep_data_grp.attrs["num_samples"] = traj["actions"].shape[0] # 该回合的样本数量
total_samples += traj["actions"].shape[0] # 更新总样本数
rollout_stats = TensorUtils.list_of_flat_dict_to_dict_of_list(rollout_stats)
avg_rollout_stats = { k : np.mean(rollout_stats[k]) for k in rollout_stats }
avg_rollout_stats["Num_Success"] = np.sum(rollout_stats["Success_Rate"])
avg_rollout_stats["Time_Episode"] = np.sum(rollout_stats["time"]) / 60. # 总耗时(分钟)
avg_rollout_stats["Num_Episode"] = len(rollout_stats["Success_Rate"]) # 总回合数
print("Average Rollout Stats")
stats_json = json.dumps(avg_rollout_stats, indent=4)
print(stats_json)
if args.json_path is not None:
json_f = open(args.json_path, "w")
json_f.write(stats_json)
json_f.close()
if write_video:
video_writer.close()
if write_dataset:
# 写入全局元数据
data_grp.attrs["total"] = total_samples
data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4) # 环境信息
data_writer.close()
print("Wrote dataset trajectories to {}".format(args.dataset_path))
Based on the command-line arguments, this function loads the policy and the environment, runs multiple rollouts, and optionally renders, records video, and saves the collected data.
3.1 Parameters
- args: the parsed command-line arguments
3.2 Argument checks and setup
# 一些参数检查
write_video = (args.video_path is not None) # 判断是否需要录制视频
assert not (args.render and write_video) # 确保不会同时进行屏幕渲染和视频录制
rate_measure = None # 用于测量动作计算速率的工具
if args.hz is not None:
import RobotTeleop
from RobotTeleop.utils import Rate, RateMeasure, Timers
rate_measure = RateMeasure(name="control_rate_measure", freq_threshold=args.hz)
- Make sure on-screen rendering and video recording are not requested at the same time; if both are set, the assertion fails
- If the control frequency args.hz is provided, import the rate-measurement utilities from RobotTeleop and create a RateMeasure
3.3 Loading the policy checkpoint
# 加载策略检查点,并获取算法名称
algo_name, ckpt_dict = FileUtils.algo_name_from_checkpoint(ckpt_path=args.agent)
- FileUtils.algo_name_from_checkpoint returns the policy's algorithm name and the checkpoint dictionary ckpt_dict
if args.dp_eval_steps is not None:
# 如果指定了dp_eval_steps参数,并且算法是diffusion_policy,则修改配置
assert algo_name == "diffusion_policy"
log_warning("setting @num_inference_steps to {}".format(args.dp_eval_steps))
# 修改配置,然后重新写回ckpt_dict
tmp_config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
with tmp_config.values_unlocked():
if tmp_config.algo.ddpm.enabled:
tmp_config.algo.ddpm.num_inference_timesteps = args.dp_eval_steps
elif tmp_config.algo.ddim.enabled:
tmp_config.algo.ddim.num_inference_timesteps = args.dp_eval_steps
else:
raise Exception("should not reach here")
ckpt_dict['config'] = tmp_config.dump()
- If dp_eval_steps is provided (only valid for a diffusion policy), the config is unlocked, the number of inference timesteps for DDPM or DDIM is overridden, and the modified config is written back into ckpt_dict
# 确定设备(CPU或GPU)
device = TorchUtils.get_torch_device(try_to_use_cuda=True)
# 从检查点恢复策略
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_dict=ckpt_dict, device=device, verbose=True)
- Select the compute device (GPU if available, otherwise CPU)
- Load the policy from the checkpoint with FileUtils.policy_from_checkpoint
3.4 Rollout settings
# 读取回合设置
rollout_num_episodes = args.n_rollouts # 回合数量
rollout_horizon = args.horizon # 回合最大步数
config, _ = FileUtils.config_from_checkpoint(ckpt_dict=ckpt_dict)
if rollout_horizon is None:
# 如果未指定horizon参数,则从配置中获取
rollout_horizon = config.experiment.rollout.horizon
- Read the number of rollouts rollout_num_episodes and the maximum number of steps per rollout rollout_horizon
- If args.horizon was not given, fall back to the horizon stored in the checkpoint config
# HACK: assume absolute actions for now if using diffusion policy on real robot
if (algo_name == "diffusion_policy") and EnvUtils.is_real_robot_gprs_env(env_meta=ckpt_dict["env_metadata"]):
ckpt_dict["env_metadata"]["env_kwargs"]["absolute_actions"] = True
- HACK: when running a diffusion policy on the real robot, absolute actions are assumed for now
- ckpt_dict["env_metadata"]["env_kwargs"]["absolute_actions"] is therefore set to True
3.5 Creating the environment
# 从保存的检查点创建环境
env, _ = FileUtils.env_from_checkpoint(
ckpt_dict=ckpt_dict,
env_name=args.env,
render=args.render,
render_offscreen=(args.video_path is not None),
verbose=True,
)
- FileUtils.env_from_checkpoint builds the environment from the checkpoint dictionary
- The rendering options follow the command-line arguments
# 如果未指定camera_names,则使用默认摄像机配置
if args.camera_names is None:
# We fill in the automatic values
env_type = EnvUtils.get_env_type(env=env)
args.camera_names = DEFAULT_CAMERAS[env_type]
if args.render:
# 屏幕渲染只能支持一个摄像机
assert len(args.camera_names) == 1
- If args.camera_names was not provided, the default cameras for the detected environment type are used; on-screen rendering supports exactly one camera
is_real_robot = EnvUtils.is_real_robot_env(env=env) or EnvUtils.is_real_robot_gprs_env(env=env)
if is_real_robot:
# 如果是真实机器人环境,记录一些警告信息
need_pause = False
if "env_name" not in ckpt_dict["env_metadata"]["env_kwargs"]:
log_warning("env_name not in checkpoint...proceed with caution...")
need_pause = True
if ckpt_dict["env_metadata"]["env_name"] != "EnvRealPandaGPRS":
# 即使策略是在不同的环境类中收集的,我们也将默认加载EnvRealPandaGPRS类
log_warning("env name in metadata appears to be class ({}) different from EnvRealPandaGPRS".format(ckpt_dict["env_metadata"]["env_name"]))
need_pause = True
if need_pause:
ans = input("continue? (y/n)")
if ans != "y":
exit()
- Detect whether this is a real-robot environment; if so, log warnings (and optionally pause for user confirmation) when the checkpoint metadata looks inconsistent
3.6 Random seed
# 如果提供了种子参数,则设置随机种子
if args.seed is not None:
np.random.seed(args.seed)
torch.manual_seed(args.seed)
- If args.seed is given, seed NumPy and PyTorch so the rollouts are reproducible
3.7 Video and dataset writers
# 可能需要创建视频写入器
video_writer = None
if write_video:
video_writer = imageio.get_writer(args.video_path, fps=20)
# 如果需要写入数据集,打开HDF5文件
write_dataset = (args.dataset_path is not None)
if write_dataset:
data_writer = h5py.File(args.dataset_path, "w")
data_grp = data_writer.create_group("data")
total_samples = 0
- If args.video_path is given, create the imageio video writer video_writer
- If args.dataset_path is given, open an HDF5 file that will store the rollout data
3.8 Running the rollouts
rollout_stats = [] # 用于存储每个回合的统计信息
for i in tqdm(range(rollout_num_episodes)): # 使用进度条显示回合进度
try:
stats, traj = rollout(
policy=policy,
env=env,
horizon=rollout_horizon,
render=args.render,
video_writer=video_writer,
video_skip=args.video_skip,
return_obs=(write_dataset and args.dataset_obs),
camera_names=args.camera_names,
real=is_real_robot,
rate_measure=rate_measure,
)
- Initialize rollout_stats, a list that collects the statistics of each rollout
- Iterate over the rollouts with a tqdm progress bar
- For each episode, call rollout() to obtain its statistics and trajectory
except KeyboardInterrupt:
if is_real_robot:
print("ctrl-C catched, stop execution")
print("env rate measure")
print(env.rate_measure)
ans = input("success? (y / n)")
rollout_stats.append((1 if ans == "y" else 0))
print("*" * 50)
print("have {} success out of {} attempts".format(np.sum(rollout_stats), len(rollout_stats)))
print("*" * 50)
continue
else:
sys.exit(0)
if is_real_robot:
print("TERMINATE WITHOUT KEYBOARD INTERRUPT...")
ans = input("success? (y / n)")
rollout_stats.append((1 if ans == "y" else 0))
continue
rollout_stats.append(stats)
- On a real robot, a KeyboardInterrupt (Ctrl-C) is caught and the user is asked whether the episode was a success; the answer is recorded in place of the automatic statistics (in simulation the script simply exits)
if write_dataset:
# 将数据集写入HDF5文件
ep_data_grp = data_grp.create_group("demo_{}".format(i))
ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
ep_data_grp.create_dataset("states", data=np.array(traj["states"]))
ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
if args.dataset_obs:
for k in traj["obs"]:
ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]))
ep_data_grp.create_dataset("next_obs/{}".format(k), data=np.array(traj["next_obs"][k]))
# 记录每个回合的元数据
if "model" in traj["initial_state_dict"]:
ep_data_grp.attrs["model_file"] = traj["initial_state_dict"]["model"] # 该回合的模型XML文件
ep_data_grp.attrs["num_samples"] = traj["actions"].shape[0] # 该回合的样本数量
total_samples += traj["actions"].shape[0] # 更新总样本数
- If dataset saving is enabled, write the trajectory (actions, states, rewards, dones, and optionally observations) plus per-episode metadata into the HDF5 file
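A dataset written this way can later be read back as follows (sketch; the file name is hypothetical):

import h5py

with h5py.File("rollouts.hdf5", "r") as f:
    demo = f["data/demo_0"]
    print(demo["actions"].shape)       # (num_samples, action_dim)
    print(demo.attrs["num_samples"])   # per-episode metadata
    print(f["data"].attrs["total"])    # total samples across all episodes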
3.9 Aggregating the results
rollout_stats = TensorUtils.list_of_flat_dict_to_dict_of_list(rollout_stats)
avg_rollout_stats = { k : np.mean(rollout_stats[k]) for k in rollout_stats }
avg_rollout_stats["Num_Success"] = np.sum(rollout_stats["Success_Rate"])
avg_rollout_stats["Time_Episode"] = np.sum(rollout_stats["time"]) / 60. # 总耗时(分钟)
avg_rollout_stats["Num_Episode"] = len(rollout_stats["Success_Rate"]) # 总回合数
print("Average Rollout Stats")
stats_json = json.dumps(avg_rollout_stats, indent=4)
print(stats_json)
- Compute the average statistics over all rollouts and print them as JSON
if args.json_path is not None:
json_f = open(args.json_path, "w")
json_f.write(stats_json)
json_f.close()
if write_video:
video_writer.close()
if write_dataset:
# 写入全局元数据
data_grp.attrs["total"] = total_samples
data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4) # 环境信息
data_writer.close()
print("Wrote dataset trajectories to {}".format(args.dataset_path))
- If args.json_path is given, dump the statistics JSON to that file
- Close the video writer and the dataset writer (also storing global metadata such as the total sample count and the serialized environment args)
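The dumped statistics can then be loaded again for later analysis (sketch; the file name is hypothetical):

import json

with open("eval_stats.json", "r") as f:
    avg_stats = json.load(f)
print(avg_stats["Success_Rate"], avg_stats["Num_Episode"])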
4. Parsing command-line arguments and running the script
# 主程序入口
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# 添加命令行参数
# Path to trained model
parser.add_argument(
"--agent",
type=str,
required=True,
help="path to saved checkpoint pth file",
)
# number of rollouts
parser.add_argument(
"--n_rollouts",
type=int,
default=27,
help="number of rollouts",
)
# maximum horizon of rollout, to override the one stored in the model checkpoint
parser.add_argument(
"--horizon",
type=int,
default=None,
help="(optional) override maximum horizon of rollout from the one in the checkpoint",
)
# Env Name (to override the one stored in model checkpoint)
parser.add_argument(
"--env",
type=str,
default=None,
help="(optional) override name of env from the one in the checkpoint, and use\
it for rollouts",
)
# Whether to render rollouts to screen
parser.add_argument(
"--render",
action='store_true',
help="on-screen rendering",
)
# Dump a video of the rollouts to the specified path
parser.add_argument(
"--video_path",
type=str,
default=None,
help="(optional) render rollouts to this video file path",
)
# How often to write video frames during the rollout
parser.add_argument(
"--video_skip",
type=int,
default=5,
help="render frames to video every n steps",
)
# camera names to render
parser.add_argument(
"--camera_names",
type=str,
nargs='+',
default=None,
help="(optional) camera name(s) to use for rendering on-screen or to video",
)
# If provided, an hdf5 file will be written with the rollout data
parser.add_argument(
"--dataset_path",
type=str,
default=None,
help="(optional) if provided, an hdf5 file will be written at this path with the rollout data",
)
# If True and @dataset_path is supplied, will write possibly high-dimensional observations to dataset.
parser.add_argument(
"--dataset_obs",
action='store_true',
help="include possibly high-dimensional observations in output dataset hdf5 file (by default,\
observations are excluded and only simulator states are saved)",
)
# for seeding before starting rollouts
parser.add_argument(
"--seed",
type=int,
default=None,
help="(optional) set seed for rollouts",
)
# Dump a json of the rollout results stats to the specified path
parser.add_argument(
"--json_path",
type=str,
default=None,
help="(optional) dump a json of the rollout results stats to the specified path",
)
# Dump a file with the error traceback at this path. Only created if run fails with an error.
parser.add_argument(
"--error_path",
type=str,
default=None,
help="(optional) dump a file with the error traceback at this path. Only created if run fails with an error.",
)
# TODO: clean up this arg
# If provided, do not run actions in env, and instead just measure the rate of action computation
parser.add_argument(
"--hz",
type=int,
default=None,
help="If provided, do not run actions in env, and instead just measure the rate of action computation and raise warnings if it dips below this threshold",
)
# TODO: clean up this arg
# If provided, set num_inference_timesteps explicitly for diffusion policy evaluation
parser.add_argument(
"--dp_eval_steps",
type=int,
default=None,
help="If provided, set num_inference_timesteps explicitly for diffusion policy evaluation",
)
args = parser.parse_args()
res_str = None
try:
run_trained_agent(args) # 调用主函数,开始策略评估
except Exception as e:
res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
if args.error_path is not None:
# 将错误信息写入指定的错误文件
f = open(args.error_path, "w")
f.write(res_str)
f.close()
raise e # 重新抛出异常
The argparse module parses the command-line arguments, which are described below.
If an exception occurs during execution, it is caught, the traceback is written to the file given by --error_path (if provided), and the exception is re-raised.
4.1 Command-line arguments
# Path to trained model
parser.add_argument(
"--agent",
type=str,
required=True,
help="path to saved checkpoint pth file",
)
- --agent: required; path to the trained policy checkpoint (.pth) file
# number of rollouts
parser.add_argument(
"--n_rollouts",
type=int,
default=27,
help="number of rollouts",
)
- --n_rollouts: number of rollouts to run (default 27)
# maximum horizon of rollout, to override the one stored in the model checkpoint
parser.add_argument(
"--horizon",
type=int,
default=None,
help="(optional) override maximum horizon of rollout from the one in the checkpoint",
)
- --horizon: optional; maximum number of steps per rollout, overriding the value stored in the checkpoint
# Env Name (to override the one stored in model checkpoint)
parser.add_argument(
"--env",
type=str,
default=None,
help="(optional) override name of env from the one in the checkpoint, and use\
it for rollouts",
)
- --env: optional; environment name, overriding the one stored in the checkpoint
# Whether to render rollouts to screen
parser.add_argument(
"--render",
action='store_true',
help="on-screen rendering",
)
- --render: enable on-screen rendering
# Dump a video of the rollouts to the specified path
parser.add_argument(
"--video_path",
type=str,
default=None,
help="(optional) render rollouts to this video file path",
)
- --video_path: path of the output video; providing it enables video recording
# How often to write video frames during the rollout
parser.add_argument(
"--video_skip",
type=int,
default=5,
help="render frames to video every n steps",
)
- --video_skip: write one frame to the video every this many steps
# camera names to render
parser.add_argument(
"--camera_names",
type=str,
nargs='+',
default=None,
help="(optional) camera name(s) to use for rendering on-screen or to video",
)
- --camera_names: camera name(s) used for rendering
# If provided, an hdf5 file will be written with the rollout data
parser.add_argument(
"--dataset_path",
type=str,
default=None,
help="(optional) if provided, an hdf5 file will be written at this path with the rollout data",
)
- --dataset_path: path of the output HDF5 file; providing it enables saving the rollout data
# If True and @dataset_path is supplied, will write possibly high-dimensional observations to dataset.
parser.add_argument(
"--dataset_obs",
action='store_true',
help="include possibly high-dimensional observations in output dataset hdf5 file (by default,\
observations are excluded and only simulator states are saved)",
)
- --dataset_obs: also store the (possibly high-dimensional) observations in the saved dataset
# for seeding before starting rollouts
parser.add_argument(
"--seed",
type=int,
default=None,
help="(optional) set seed for rollouts",
)
- --seed: random seed for the rollouts
# Dump a json of the rollout results stats to the specified path
parser.add_argument(
"--json_path",
type=str,
default=None,
help="(optional) dump a json of the rollout results stats to the specified path",
)
- --json_path: path to dump the rollout statistics as JSON
# Dump a file with the error traceback at this path. Only created if run fails with an error.
parser.add_argument(
"--error_path",
type=str,
default=None,
help="(optional) dump a file with the error traceback at this path. Only created if run fails with an error.",
)
- --error_path: path to dump the error traceback if the run fails
# TODO: clean up this arg
# If provided, do not run actions in env, and instead just measure the rate of action computation
parser.add_argument(
"--hz",
type=int,
default=None,
help="If provided, do not run actions in env, and instead just measure the rate of action computation and raise warnings if it dips below this threshold",
)
- --hz: if provided, actions are not executed; only the rate of action computation is measured, with warnings when it drops below this threshold
# TODO: clean up this arg
# If provided, set num_inference_timesteps explicitly for diffusion policy evaluation
parser.add_argument(
"--dp_eval_steps",
type=int,
default=None,
help="If provided, set num_inference_timesteps explicitly for diffusion policy evaluation",
)
- --dp_eval_steps: number of inference timesteps to use when evaluating a diffusion policy
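As a concrete (hypothetical) example, an evaluation run that records a video and dumps statistics could be launched with: python run_trained_agent.py --agent model.pth --n_rollouts 10 --horizon 400 --seed 0 --video_path eval.mp4 --video_skip 5 --camera_names agentview --json_path eval_stats.json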
4.2 Running the script
args = parser.parse_args()
res_str = None
try:
run_trained_agent(args) # 调用主函数,开始策略评估
except Exception as e:
res_str = "run failed with error:\n{}\n\n{}".format(e, traceback.format_exc())
if args.error_path is not None:
# 将错误信息写入指定的错误文件
f = open(args.error_path, "w")
f.write(res_str)
f.close()
raise e # 重新抛出异常