robomimic应用教程(二)——策略运行与评估

得到训练好的pth后,下一并将其进行部署及效果评估

可以在jupyter notebook中进行此操作,文件为robomimic文件夹中的examples/notebooks/run_policy.ipynb

本文采用pycharm调试

该脚本用于在环境中评估策略,主要包括从model zoo下载checkpoint,在pytorch中加载checkpoint,并运行评估policy

目录

一、参数说明

二、Terminal运行

1. 执行评估策略

2. 保存方法1

3. 保存方法2

三、逐步运行与解析(run a trained policy and visualize the rollout)

1. 库引用

2. 下载policy checkpoint

3. 加载trained policy

4. 创建 rollout 环境

5. 定义 rollout 循环

6. 运行policy

7. 可视化 rollout

8.运行结果


一、参数说明

"""
The main script for evaluating a policy in an environment.

Args:
    agent (str): path to saved checkpoint pth file

    horizon (int): if provided, override maximum horizon of rollout from the one 
        in the checkpoint

    env (str): if provided, override name of env from the one in the checkpoint,
        and use it for rollouts

    render (bool): if flag is provided, use on-screen rendering during rollouts

    video_path (str): if provided, render trajectories to this video file path

    video_skip (int): render frames to a video every @video_skip steps

    camera_names (str or [str]): camera name(s) to use for rendering on-screen or to video

    dataset_path (str): if provided, an hdf5 file will be written at this path with the
        rollout data

    dataset_obs (bool): if flag is provided, and @dataset_path is provided, include 
        possible high-dimensional observations in output dataset hdf5 file (by default,
        observations are excluded and only simulator states are saved).

    seed (int): if provided, set seed for rollouts

Example usage:

    # Evaluate a policy with 50 rollouts of maximum horizon 400 and save the rollouts to a video.
    # Visualize the agentview and wrist cameras during the rollout.
    
    python run_trained_agent.py --agent /path/to/model.pth \
        --n_rollouts 50 --horizon 400 --seed 0 \
        --video_path /path/to/output.mp4 \
        --camera_names agentview robot0_eye_in_hand 

    # Write the 50 agent rollouts to a new dataset hdf5.

    python run_trained_agent.py --agent /path/to/model.pth \
        --n_rollouts 50 --horizon 400 --seed 0 \
        --dataset_path /path/to/output.hdf5 --dataset_obs 

    # Write the 50 agent rollouts to a new dataset hdf5, but exclude the dataset observations
    # since they might be high-dimensional (they can be extracted again using the
    # dataset_states_to_obs.py script).

    python run_trained_agent.py --agent /path/to/model.pth \
        --n_rollouts 50 --horizon 400 --seed 0 \
        --dataset_path /path/to/output.hdf5
"""

agent (str): 已保存的检查点模型文件路径(.pth文件)。

horizon (int):如果提供,将覆盖检查点中的最大回合长度,即在评估中运行多少步

env (str): 如果提供,将覆盖检查点中保存的环境名称,用于运行策略时创建新的环境

render (bool):如果提供该标志,在每次回合执行时显示屏幕上的实时渲染

video_path (str):如果提供,将回合过程录制为视频并保存到指定的路径

video_skip (int):每隔 @video_skip 步渲染一次帧到视频中

camera_names (str 或 [str]):指定用于渲染的相机名称。

dataset_path (str):如果提供,将回合数据写入到指定路径的hdf5文件中

dataset_obs (bool):如果flag及@dataset_path提供,hdf5文件中将包含可能的高维观测数据(默认情况下,仅保存模拟器状态)

seed (int): 如果提供,设置回合的随机种子

二、Terminal运行

保存回合到视频并渲染相机视图

评估一个策略,50次滚动,最大地平线400,并将滚动保存到视频中

在过程中可视化agentview和手腕摄像头

1. 执行评估策略

进行50次回合,每次最多运行400步,并将回合保存为视频文件,回合过程中显示agentview和wrist相机视角的画面

python run_trained_agent.py --agent /path/to/model.pth \
    --n_rollouts 50 --horizon 400 --seed 0 \
    --video_path /path/to/output.mp4 \
    --camera_names agentview robot0_eye_in_hand

2. 保存方法1

将50次回合保存到hdf5数据集文件中

python run_trained_agent.py --agent /path/to/model.pth \
    --n_rollouts 50 --horizon 400 --seed 0 \
    --dataset_path /path/to/output.hdf5 --dataset_obs

3. 保存方法2

将50次回合保存到hdf5数据集文件中,但不包含数据观测数据,因为这些数据可能是高维的

可以使用dataset_states_to_obs.py脚本再次提取这些观测数据

python run_trained_agent.py --agent /path/to/model.pth \
    --n_rollouts 50 --horizon 400 --seed 0 \
    --dataset_path /path/to/output.hdf5

三、逐步运行与解析(run a trained policy and visualize the rollout)

官方提供了python代码的示例,先看一遍相关示例,再进一步解析

示例采用jupyter演示,库环境默认已安装好robomimic和robosuite

1. 库引用

import argparse
import json
import h5py
import imageio
import numpy as np
import os
from copy import deepcopy

import torch

import robomimic
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.envs.env_base import EnvBase
from robomimic.algo import RolloutPolicy

import urllib.request

2. 下载policy checkpoint

从model zoo中下载pretrained model

此处简单的说明一下 checkpoint,在深度学习模型中,checkpoint 是指在训练过程中保存的模型状态,通常包含模型的参数(权重和偏置)、优化器的状态以及其他相关的训练信息

在训练过程中定期保存模型 checkpoint,就可以在需要时恢复训练或用于模型评估和推理

模型的参数(权重和偏置)文件:在 TensorFlow 中通常是 .cpkt 文件,在 PyTorch 中通常是 .pt.pth 文件

# Get pretrained checkpooint from the model zoo

ckpt_path = "lift_ph_low_dim_epoch_1000_succ_100.pth"
# Lift (Proficient Human)
urllib.request.urlretrieve(
    "http://downloads.cs.stanford.edu/downloads/rt_benchmark/model_zoo/lift/bc_rnn/lift_ph_low_dim_epoch_1000_succ_100.pth",
    filename=ckpt_path
)

assert os.path.exists(ckpt_path)

3. 加载trained policy

调用 policy_from_checkpoint 函数,从 checkpoint 中构建正确的模型并加载训练好的权重,也可以手动加载 checkpoint

device = TorchUtils.get_torch_device(try_to_use_cuda=True)

# restore policy
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path=ckpt_path, device=device, verbose=True)

4. 创建 rollout 环境

此处简单的说明一下 rollout,直接翻译为 “推演或者模拟”,通常是智能体和环境以及模型交互的过程中产生的一系列的交互历史轨迹,通过收集数据,来评估或改进当前的策略

一个 rollout 可以包含一个或多个完整的episodes,或者只是一个episode的一部分(在实际应用中,通常一个rollout只包含一个episode的数据)

policy checkpoint 包含足够的信息来重新创建训练它的环境(也可以手动创建环境)

# create environment from saved checkpoint
env, _ = FileUtils.env_from_checkpoint(
    ckpt_dict=ckpt_dict, 
    render=False, # we won't do on-screen rendering in the notebook
    render_offscreen=True, # render to RGB images for video
    verbose=True,
)

5. 定义 rollout 循环

定义主 rollout 循环,该循环将运行 policy 到目标 horizon,并将 rollout 写入视频(可选)

def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, camera_names=None):
    """
    Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video, 
    and returns the rollout trajectory.
    Args:
        policy (instance of RolloutPolicy): policy loaded from a checkpoint
        env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
        horizon (int): maximum horizon for the rollout
        render (bool): whether to render rollout on-screen
        video_writer (imageio writer): if provided, use to write rollout to video
        video_skip (int): how often to write video frames
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
    Returns:
        stats (dict): some statistics for the rollout - such as return, horizon, and task success
    """
    assert isinstance(env, EnvBase)
    assert isinstance(policy, RolloutPolicy)
    assert not (render and (video_writer is not None))

    policy.start_episode()
    obs = env.reset()
    state_dict = env.get_state()

    # hack that is necessary for robosuite tasks for deterministic action playback
    obs = env.reset_to(state_dict)

    results = {}
    video_count = 0  # video frame counter
    total_reward = 0.
    try:
        for step_i in range(horizon):

            # get action from policy
            act = policy(ob=obs)

            # play action
            next_obs, r, done, _ = env.step(act)

            # compute reward
            total_reward += r
            success = env.is_success()["task"]

            # visualization
            if render:
                env.render(mode="human", camera_name=camera_names[0])
            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = []
                    for cam_name in camera_names:
                        video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                    video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
                    video_writer.append_data(video_img)
                video_count += 1

            # break if done or if success
            if done or success:
                break

            # update for next iter
            obs = deepcopy(next_obs)
            state_dict = env.get_state()

    except env.rollout_exceptions as e:
        print("WARNING: got rollout exception {}".format(e))

    stats = dict(Return=total_reward, Horizon=(step_i + 1), Success_Rate=float(success))

    return stats

6. 运行policy

此处请仔细看,和 terminal 运行指令相同,赋予参数

rollout_horizon = 400
np.random.seed(0)
torch.manual_seed(0)
video_path = "rollout.mp4"
video_writer = imageio.get_writer(video_path, fps=20)

运行主函

stats = rollout(
    policy=policy, 
    env=env, 
    horizon=rollout_horizon, 
    render=False, 
    video_writer=video_writer, 
    video_skip=5, 
    camera_names=["agentview"]
)
print(stats)
video_writer.close()

7. 可视化 rollout

from IPython.display import Video
Video(video_path)

8.运行结果

下载了数据集,生成了视频,看一下是夹取动作

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值