NVIDIA Isaac GR00T N1 Humanoid Robot Reinforcement Learning Getting-Started Tutorial (Part 2): Detailed Steps

Series Contents

Table of Contents

Series Contents

Preface

1. Guide to Loading the Dataset

1.1 Loading the Dataset

1.2 Transforming the Data

2. GR00T Inference

2.1 Loading the Pretrained Policy

2.2 Loading the Dataset

2.3 Understanding the Action Output

3. Fine-Tuning

3.1 The Pretrained Model

3.2 Fine-Tuning the Model


Preface


1. Guide to Loading the Dataset

        The LeRobot format

  • This tutorial shows how to load data in the LeRobot format using our data loader.
  • As an example, we use the robot_sim.PickNPlace dataset, which has already been converted to the LeRobot format.
  • To learn how to convert your own dataset, see GR00T's LeRobot.md.
from gr00t.utils.misc import any_describe
from gr00t.data.dataset import LeRobotSingleDataset
from gr00t.data.dataset import ModalityConfig
from gr00t.data.schema import EmbodimentTag

1.1 Loading the Dataset

To load the dataset, we need to define the following three things:

  1. The path to the dataset
  2. The modality configs (ModalityConfigs)
    1. ModalityConfigs define which data modalities (e.g. video, state, action, language) are used downstream, such as for model training or inference.
    2. Each modality specifies which frames to load via delta_indices (e.g. [0] loads only the current frame, [-1, 0] loads the previous and the current frame); see the short sketch after this list.
  3. The embodiment tag (EmbodimentTag)
    1. The EmbodimentTag specifies which embodiment the dataset comes from. The full list of embodiment tags can be found in gr00t.data.embodiment_tags.EmbodimentTag.
    2. GR00T's architecture uses different action heads optimized for specific robot types (embodiments), and the EmbodimentTag tells the model which action head to use for fine-tuning and/or inference. In our case, since we are working with a humanoid robot, we specify EmbodimentTag.GR1 to get the best performance from the humanoid-specific action head.
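Below is a minimal sketch (not from the original tutorial) of a two-frame video config: the only change from the config used in the code that follows is delta_indices=[-1, 0], which loads the previous frame together with the current one.

from gr00t.data.dataset import ModalityConfig

# Hypothetical example: load the previous AND the current ego-view frame for each sample
two_frame_video_config = ModalityConfig(
    delta_indices=[-1, 0],             # previous frame and current frame
    modality_keys=["video.ego_view"],  # same camera key as in the demo dataset below
)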
import os
import gr00t

# REPO_PATH is the path of the pip install gr00t repo and one level up
REPO_PATH = os.path.dirname(os.path.dirname(gr00t.__file__))
DATA_PATH = os.path.join(REPO_PATH, "demo_data/robot_sim.PickNPlace")

print("Loading dataset... from", DATA_PATH)
# 2. modality configs
modality_configs = {
    "video": ModalityConfig(
        delta_indices=[0],
        modality_keys=["video.ego_view"],
    ),
    "state": ModalityConfig(
        delta_indices=[0],
        modality_keys=[
            "state.left_arm",
            "state.left_hand",
            "state.left_leg",
            "state.neck",
            "state.right_arm",
            "state.right_hand",
            "state.right_leg",
            "state.waist",
        ],
    ),
    "action": ModalityConfig(
        delta_indices=[0],
        modality_keys=[
            "action.left_hand",
            "action.right_hand",
        ],
    ),
    "language": ModalityConfig(
        delta_indices=[0],
        modality_keys=["annotation.human.action.task_description", "annotation.human.validity"],
    ),
}
# 3. gr00t embodiment tag
embodiment_tag = EmbodimentTag.GR1

# load the dataset
dataset = LeRobotSingleDataset(DATA_PATH, modality_configs,  embodiment_tag=embodiment_tag)

print('\n'*2)
print("="*100)
print(f"{' Humanoid Dataset ':=^100}")
print("="*100)

# print the 7th data point
resp = dataset[7]
any_describe(resp)
print(resp.keys())
Initialized dataset robot_sim.PickNPlace with EmbodimentTag.GR1



====================================================================================================
========================================= Humanoid Dataset =========================================
====================================================================================================
{'action.left_hand': 'np: [1, 6] float64',
 'action.right_hand': 'np: [1, 6] float64',
 'annotation.human.action.task_description': ['pick the squash from the '
                                              'counter and place it in the '
                                              'plate'],
 'annotation.human.validity': ['valid'],
 'state.left_arm': 'np: [1, 7] float64',
 'state.left_hand': 'np: [1, 6] float64',
 'state.left_leg': 'np: [1, 6] float64',
 'state.neck': 'np: [1, 3] float64',
 'state.right_arm': 'np: [1, 7] float64',
 'state.right_hand': 'np: [1, 6] float64',
 'state.right_leg': 'np: [1, 6] float64',
 'state.waist': 'np: [1, 3] float64',
 'video.ego_view': 'np: [1, 256, 256, 3] uint8'}
dict_keys(['video.ego_view', 'state.left_arm', 'state.left_hand', 'state.left_leg', 'state.neck', 'state.right_arm', 'state.right_hand', 'state.right_leg', 'state.waist', 'action.left_hand', 'action.right_hand', 'annotation.human.action.task_description', 'annotation.human.validity'])

        Displaying image frames from the data

# show img
import matplotlib.pyplot as plt

images_list = []

for i in range(100):
    if i % 10 == 0:
        resp = dataset[i]
        img = resp["video.ego_view"][0]
        images_list.append(img)


fig, axs = plt.subplots(2, 5, figsize=(20, 10))
for i, ax in enumerate(axs.flat):
    ax.imshow(images_list[i])
    ax.axis("off")
    ax.set_title(f"Image {i}")
plt.tight_layout() # adjust the subplots to fit into the figure area.
plt.show()

1.2 Transforming the Data

        We can also apply a series of transforms to the data within the LeRobotSingleDataset class. The following shows how to transform the data.

from gr00t.data.transform.base import ComposedModalityTransform
from gr00t.data.transform import VideoToTensor, VideoCrop, VideoResize, VideoColorJitter, VideoToNumpy
from gr00t.data.transform.state_action import StateActionToTensor, StateActionTransform
from gr00t.data.transform.concat import ConcatTransform


video_modality = modality_configs["video"]
state_modality = modality_configs["state"]
action_modality = modality_configs["action"]

# select the transforms you want to apply to the data
to_apply_transforms = ComposedModalityTransform(
    transforms=[
        # video transforms
        VideoToTensor(apply_to=video_modality.modality_keys),
        VideoCrop(apply_to=video_modality.modality_keys, scale=0.95),
        VideoResize(apply_to=video_modality.modality_keys, height=224, width=224, interpolation="linear"),
        VideoColorJitter(apply_to=video_modality.modality_keys, brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08),
        VideoToNumpy(apply_to=video_modality.modality_keys),

        # state transforms
        StateActionToTensor(apply_to=state_modality.modality_keys),
        StateActionTransform(apply_to=state_modality.modality_keys, normalization_modes={
            key: "min_max" for key in state_modality.modality_keys
        }),

        # action transforms
        StateActionToTensor(apply_to=action_modality.modality_keys),
        StateActionTransform(apply_to=action_modality.modality_keys, normalization_modes={
            key: "min_max" for key in action_modality.modality_keys
        }),

        # ConcatTransform
        ConcatTransform(
            video_concat_order=video_modality.modality_keys,
            state_concat_order=state_modality.modality_keys,
            action_concat_order=action_modality.modality_keys,
        ),
    ]
)

        Now let's see how the data differs after the transforms are applied.

        For example, the states and actions are normalized and concatenated, and the video images are cropped, resized, and color-jittered.

dataset = LeRobotSingleDataset(
    DATA_PATH,
    modality_configs,
    transforms=to_apply_transforms,
    embodiment_tag=embodiment_tag
)

# print the 7th data point
resp = dataset[7]
any_describe(resp)
print(resp.keys())
Initialized dataset robot_sim.PickNPlace with EmbodimentTag.GR1
{'action': 'torch: [1, 12] torch.float64 cpu',
 'annotation.human.action.task_description': ['pick the squash from the '
                                              'counter and place it in the '
                                              'plate'],
 'annotation.human.validity': ['valid'],
 'state': 'torch: [1, 44] torch.float64 cpu',
 'video': 'np: [1, 1, 224, 224, 3] uint8'}
dict_keys(['annotation.human.action.task_description', 'annotation.human.validity', 'video', 'state', 'action'])

2. GR00T Inference

        This tutorial shows how to use the GR00T inference model to predict actions from observations in a test dataset.

import os
import torch
import gr00t

from gr00t.data.dataset import LeRobotSingleDataset
from gr00t.model.policy import Gr00tPolicy
# change the following paths
MODEL_PATH = "nvidia/GR00T-N1-2B"

# REPO_PATH is the path of the pip install gr00t repo and one level up
REPO_PATH = os.path.dirname(os.path.dirname(gr00t.__file__))
DATASET_PATH = os.path.join(REPO_PATH, "demo_data/robot_sim.PickNPlace")
EMBODIMENT_TAG = "gr1"

device = "cuda" if torch.cuda.is_available() else "cpu"

2.1 Loading the Pretrained Policy

        The policy model is loaded just like any other Hugging Face model.

        There are two new concepts in the GR00T model:

  • Modality config (modality_config): defines the keys of the data dictionary used by the model (e.g. action, state, annotation, video).
  • Modality transform (modality_transform): the sequence of transforms applied during data loading.
from gr00t.experiment.data_config import DATA_CONFIG_MAP


data_config = DATA_CONFIG_MAP["gr1_arms_only"]
modality_config = data_config.modality_config()
modality_transform = data_config.transform()

policy = Gr00tPolicy(
    model_path=MODEL_PATH,
    embodiment_tag=EMBODIMENT_TAG,
    modality_config=modality_config,
    modality_transform=modality_transform,
    device=device,
)

# print out the policy model architecture
print(policy.model)
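As a quick sanity check (a small sketch, not part of the original notebook), you can list every data config registered in DATA_CONFIG_MAP. The exact set of keys depends on your installed gr00t version, but the "gr1_arms_only" config used above should be among them.

from gr00t.experiment.data_config import DATA_CONFIG_MAP

# Print all registered data-config names; each maps to a modality-config / transform pair
for name in sorted(DATA_CONFIG_MAP):
    print(name)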

2.2 Loading the Dataset

        First, the user needs to check which modality config is used by the pretrained Gr00tPolicy model.

import numpy as np

modality_config = policy.modality_config

print(modality_config.keys())

for key, value in modality_config.items():
    if isinstance(value, np.ndarray):
        print(key, value.shape)
    else:
        print(key, value)
dict_keys(['video', 'state', 'action', 'language'])
video delta_indices=[0] modality_keys=['video.ego_view']
state delta_indices=[0] modality_keys=['state.left_arm', 'state.right_arm', 'state.left_hand', 'state.right_hand']
action delta_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] modality_keys=['action.left_arm', 'action.right_arm', 'action.left_hand', 'action.right_hand']
language delta_indices=[0] modality_keys=['annotation.human.action.task_description']
# Create the dataset
dataset = LeRobotSingleDataset(
    dataset_path=DATASET_PATH,
    modality_configs=modality_config,
    video_backend="decord",
    video_backend_kwargs=None,
    transforms=None,  # We'll handle transforms separately through the policy
    embodiment_tag=EMBODIMENT_TAG,
)
Initialized dataset robot_sim.PickNPlace with gr1

        Let's print out one data point and visualize it.

import numpy as np

step_data = dataset[0]

print(step_data)

print("\n\n ====================================")
for key, value in step_data.items():
    if isinstance(value, np.ndarray):
        print(key, value.shape)
    else:
        print(key, value)
{'video.ego_view': array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        ...,

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]]], dtype=uint8), 'state.left_arm': array([[-0.01147083,  0.12207967,  0.04229397, -2.1       , -0.01441445,
        -0.03013532, -0.00384387]]), 'state.right_arm': array([[ 6.74682520e-03, -9.05242648e-02,  8.14010333e-03,
        -2.10000000e+00, -2.22802847e-02,  1.40373494e-02,
         1.04679727e-03]]), 'state.left_hand': array([[0.00067388, 0.00076318, 0.00084776, 0.00069391, 0.00075118,
        0.00485883]]), 'state.right_hand': array([[0.00187319, 0.00216486, 0.002383  , 0.00182536, 0.02100907,
        0.01359221]]), 'action.left_arm': array([[-1.07706680e-02,  1.04934344e-01,  4.15936080e-02,
        -2.09762367e+00, -1.78164827e-02, -3.33876667e-02,
        -9.85330611e-03],
       [-7.91579173e-03,  9.78180763e-02,  3.57741932e-02,
        -2.09809624e+00, -1.47751639e-02, -2.78760650e-02,
        -1.44522492e-02],
       [-5.48170464e-03,  8.97774305e-02,  3.09967920e-02,
        -2.09854670e+00, -1.20841815e-02, -2.37416655e-02,
        -1.86673091e-02],
       [-3.23207919e-03,  8.14548305e-02,  2.66160698e-02,
        -2.09891412e+00, -9.46781410e-03, -2.01521622e-02,
        -2.27768122e-02],
       [-1.03295768e-03,  7.30429523e-02,  2.23631483e-02,
        -2.09919641e+00, -6.85977069e-03, -1.68003940e-02,
        -2.69124769e-02],
       [ 1.18218133e-03,  6.45925486e-02,  1.81263910e-02,
        -2.09939841e+00, -4.25067524e-03, -1.35736444e-02,
        -3.11342315e-02],
       [ 4.01225156e-03,  6.74696222e-02,  9.50461084e-03,
        -2.08267710e+00,  1.30771244e-02, -2.74863628e-02,
        -6.21405030e-03],
       [ 5.70485217e-03,  6.30451006e-02,  5.21468384e-03,
        -2.07986215e+00,  1.71772091e-02, -2.82865559e-02,
        -4.63431140e-03],
       [ 7.39730319e-03,  5.89018901e-02,  1.03944234e-03,
        -2.07728093e+00,  2.08109612e-02, -2.90521862e-02,
        -3.36559193e-03],
       [ 9.12183520e-03,  5.53131414e-02, -2.71213219e-03,
        -2.07524967e+00,  2.33213483e-02, -2.95806653e-02,
        -2.84210159e-03],
       [ 1.09992156e-02,  5.22728198e-02, -6.07222726e-03,
        -2.07378112e+00,  2.45457848e-02, -2.98420447e-02,
        -3.15860443e-03],
       [ 1.30662712e-02,  4.97582401e-02, -9.05391049e-03,
        -2.07285030e+00,  2.45056922e-02, -2.98503259e-02,
        -4.26775116e-03],
       [ 1.52062086e-02,  4.77021053e-02, -1.16503113e-02,
        -2.07234023e+00,  2.35690492e-02, -2.96972474e-02,
        -5.86959074e-03],
       [ 1.73108239e-02,  4.59530903e-02, -1.39242263e-02,
        -2.07209143e+00,  2.21695462e-02, -2.94920861e-02,
        -7.65797543e-03],
       [ 1.92511046e-02,  4.44102571e-02, -1.58636369e-02,
        -2.07202558e+00,  2.06067592e-02, -2.93277413e-02,
        -9.35161262e-03],
       [ 2.07126401e-02,  4.32274177e-02, -1.72182433e-02,
        -2.07203352e+00,  1.92976598e-02, -2.92817179e-02,
        -1.06464981e-02]]), 'action.right_arm': array([[ 2.21627088e-03, -9.37728608e-02, -3.97849900e-02,
        -2.06814219e+00, -8.56739215e-02,  5.79283621e-02,
        -5.77253327e-02],
       [ 4.30219780e-05, -1.04490613e-01, -7.72047014e-02,
        -2.04360052e+00, -1.49008809e-01,  9.11731368e-02,
        -1.09504733e-01],
       [-3.20656176e-03, -1.13106628e-01, -1.10705757e-01,
        -2.02076698e+00, -2.06495159e-01,  1.24586380e-01,
        -1.55423898e-01],
       [-7.43041223e-03, -1.21159332e-01, -1.41850206e-01,
        -1.99755448e+00, -2.60498041e-01,  1.59194985e-01,
        -1.97590390e-01],
       [-1.24604636e-02, -1.29517659e-01, -1.71688775e-01,
        -1.97291110e+00, -3.12480586e-01,  1.95620004e-01,
        -2.37292811e-01],
       [-1.81944020e-02, -1.38630148e-01, -2.00705537e-01,
        -1.94634574e+00, -3.63152746e-01,  2.34135045e-01,
        -2.75149288e-01],
       [ 8.93425381e-02, -4.26527765e-01, -1.26975696e-01,
        -1.93934416e+00,  9.87892979e-02,  2.32846570e-01,
        -2.87245118e-02],
       [ 9.94114267e-02, -4.84456898e-01, -1.19981059e-01,
        -1.92916311e+00,  1.56401124e-01,  2.58826792e-01,
        -5.84780584e-03],
       [ 1.05915501e-01, -5.33586018e-01, -1.09268820e-01,
        -1.92013251e+00,  2.07781017e-01,  2.81349765e-01,
         1.95781464e-02],
       [ 1.05701101e-01, -5.70159727e-01, -9.94966466e-02,
        -1.91462445e+00,  2.47525017e-01,  2.97567721e-01,
         4.42939261e-02],
       [ 9.88279356e-02, -5.93690835e-01, -8.82906547e-02,
        -1.91157439e+00,  2.70735239e-01,  3.12278013e-01,
         7.16265080e-02],
       [ 8.50292359e-02, -6.07675457e-01, -8.17918138e-02,
        -1.91183875e+00,  2.81771715e-01,  3.23118613e-01,
         1.02626036e-01],
       [ 6.71502302e-02, -6.18867876e-01, -7.60968035e-02,
        -1.91468068e+00,  2.81581396e-01,  3.40363262e-01,
         1.35024227e-01],
       [ 4.65002758e-02, -6.27704616e-01, -7.52606241e-02,
        -1.91926019e+00,  2.75164718e-01,  3.59271051e-01,
         1.67061231e-01],
       [ 2.45882211e-02, -6.33134721e-01, -7.88275961e-02,
        -1.92399113e+00,  2.61009603e-01,  3.81247479e-01,
         1.98397058e-01],
       [ 5.03588613e-03, -6.39530345e-01, -8.39274535e-02,
        -1.92642545e+00,  2.47624030e-01,  4.06113278e-01,
         2.24670301e-01]]), 'action.left_hand': array([[-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ]]), 'action.right_hand': array([[-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ],
       [-1.5, -1.5, -1.5, -1.5, -3. ,  3. ]]), 'annotation.human.action.task_description': ['pick the squash from the counter and place it in the plate']}


 ====================================
video.ego_view (1, 256, 256, 3)
state.left_arm (1, 7)
state.right_arm (1, 7)
state.left_hand (1, 6)
state.right_hand (1, 6)
action.left_arm (16, 7)
action.right_arm (16, 7)
action.left_hand (16, 6)
action.right_hand (16, 6)
annotation.human.action.task_description ['pick the squash from the counter and place it in the plate']

        Let's plot only the right-arm state and action data to see what they look like, and also display a few of the corresponding ego-view images.

import matplotlib.pyplot as plt

traj_id = 0
max_steps = 150

state_joints_across_time = []
gt_action_joints_across_time = []
images = []

sample_images = 6

for step_count in range(max_steps):
    data_point = dataset.get_step_data(traj_id, step_count)
    state_joints = data_point["state.right_arm"][0]
    gt_action_joints = data_point["action.right_arm"][0]
    
   
    state_joints_across_time.append(state_joints)
    gt_action_joints_across_time.append(gt_action_joints)

    # We can also get the image data
    if step_count % (max_steps // sample_images) == 0:
        image = data_point["video.ego_view"][0]
        images.append(image)

# Size is (max_steps, num_joints == 7)
state_joints_across_time = np.array(state_joints_across_time)
gt_action_joints_across_time = np.array(gt_action_joints_across_time)


# Plot the joint angles across time
fig, axes = plt.subplots(nrows=7, ncols=1, figsize=(8, 2*7))

for i, ax in enumerate(axes):
    ax.plot(state_joints_across_time[:, i], label="state joints")
    ax.plot(gt_action_joints_across_time[:, i], label="gt action joints")
    ax.set_title(f"Joint {i}")
    ax.legend()

plt.tight_layout()
plt.show()


# Plot the images in a row
fig, axes = plt.subplots(nrows=1, ncols=sample_images, figsize=(16, 4))

for i, ax in enumerate(axes):
    ax.imshow(images[i])
    ax.axis("off")
    

        Now we can run the policy from the pretrained checkpoint.

predicted_action = policy.get_action(step_data)
for key, value in predicted_action.items():
    print(key, value.shape)
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
action.left_arm (16, 7)
action.right_arm (16, 7)
action.left_hand (16, 6)
action.right_hand (16, 6)

2.3 Understanding the Action Output

        Each joint group in the action output has shape (16, N), where N is the number of degrees of freedom of that joint group.

  • 16 is the action horizon (predictions for timesteps t, t+1, t+2, ..., t+15); a short usage sketch follows the joint breakdown below.

        Each arm (left and right)

  • 7 arm joints
    • shoulder pitch
    • shoulder roll
    • shoulder yaw
    • elbow pitch
    • wrist yaw
    • wrist roll
    • wrist pitch

        Each hand (left and right)

  • 6 finger joints
    • pinky
    • ring finger
    • middle finger
    • index finger
    • thumb rotation
    • thumb bend

        Waist

  • 3 joints
    • torso waist yaw
    • torso waist pitch
    • torso waist roll
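To make the (16, N) layout concrete, here is a small sketch that assumes the keys and shapes printed above; the robot.send_command call is a hypothetical controller interface, not part of gr00t. It extracts the action for the current timestep and then steps through the full horizon.

# Action for the current timestep t: index 0 along the horizon dimension
current_action = {key: value[0] for key, value in predicted_action.items()}
for key, value in current_action.items():
    print(key, value.shape)  # e.g. action.right_arm (7,), action.right_hand (6,)

# Step through the 16-step horizon, e.g. to stream commands to a controller
action_horizon = predicted_action["action.right_arm"].shape[0]  # 16
for t in range(action_horizon):
    right_arm_cmd = predicted_action["action.right_arm"][t]    # 7 joint targets
    right_hand_cmd = predicted_action["action.right_hand"][t]  # 6 joint targets
    # robot.send_command(right_arm_cmd, right_hand_cmd)  # hypothetical robot interface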

3. Fine-Tuning

        This tutorial shows how to fine-tune a GR00T-N1 pretrained checkpoint on a post-training dataset with the same embodiment. It demonstrates the benefit of post-training: turning a generalist model into a specialist model, with measurably better performance.

        In this tutorial, we will use the demo dataset robot_sim.PickNPlace from the demo_data folder.

        We will first load the pretrained model and evaluate it on the dataset. Then we will fine-tune the model on the dataset and evaluate its performance again.

3.1 The Pretrained Model

from gr00t.utils.eval import calc_mse_for_single_trajectory
import warnings
from gr00t.experiment.data_config import DATA_CONFIG_MAP
from gr00t.model.policy import Gr00tPolicy
from gr00t.data.schema import EmbodimentTag
from gr00t.data.dataset import LeRobotSingleDataset
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

warnings.simplefilter("ignore", category=FutureWarning)
/home/youliang/miniconda3/envs/groot-release/lib/python3.10/site-packages/albumentations/__init__.py:13: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.18). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
  check_for_updates()
2025-03-18 13:27:04.588611: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-18 13:27:04.610590: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-18 13:27:04.610610: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-18 13:27:04.611353: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-18 13:27:04.615698: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-03-18 13:27:04.959771: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
PRE_TRAINED_MODEL_PATH = "nvidia/GR00T-N1-2B"
EMBODIMENT_TAG = EmbodimentTag.GR1
DATASET_PATH = "../demo_data/robot_sim.PickNPlace"


data_config = DATA_CONFIG_MAP["gr1_arms_only"]
modality_config = data_config.modality_config()
modality_transform = data_config.transform()


pre_trained_policy = Gr00tPolicy(
    model_path=PRE_TRAINED_MODEL_PATH,
    embodiment_tag=EMBODIMENT_TAG,
    modality_config=modality_config,
    modality_transform=modality_transform,
    device=device,
)

dataset = LeRobotSingleDataset(
    dataset_path=DATASET_PATH,
    modality_configs=modality_config,
    video_backend="decord",
    video_backend_kwargs=None,
    transforms=None,  # We'll handle transforms separately through the policy
    embodiment_tag=EMBODIMENT_TAG,
)


mse = calc_mse_for_single_trajectory(
    pre_trained_policy,
    dataset,
    traj_id=0,
    modality_keys=["right_arm", "right_hand"],   # we will only evaluate the right arm and right hand
    steps=150,
    action_horizon=16,
    plot=True
)

print("MSE loss for trajectory 0:", mse)
Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]
Loading pretrained dual brain from /home/youliang/.cache/huggingface/hub/models--nvidia--GR00T-N1-2B/snapshots/32e1fd2507f7739fad443e6b449c8188e0e02fcb
Tune backbone vision tower: True
Tune backbone LLM: False
Tune action head projector: True
Tune action head DiT: True
Model not found or avail in the huggingface hub. Loading from local path: /home/youliang/.cache/huggingface/hub/models--nvidia--GR00T-N1-2B/snapshots/32e1fd2507f7739fad443e6b449c8188e0e02fcb
Total number of DiT parameters:  537803776
Tune action head projector: True
Tune action head diffusion model: True
Some weights of the model checkpoint at /home/youliang/.cache/huggingface/hub/models--nvidia--GR00T-N1-2B/snapshots/32e1fd2507f7739fad443e6b449c8188e0e02fcb were not used when initializing GR00T_N1: ['action_head.decode_layer.bias', 'action_head.decode_layer.weight']
- This IS expected if you are initializing GR00T_N1 from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GR00T_N1 from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Tune action head projector: True
Tune action head diffusion model: True
Initialized dataset robot_sim.PickNPlace with EmbodimentTag.GR1
inferencing at step:  0
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 3.6764835882576428

Great! We can see the predicted actions alongside the ground-truth actions. The predictions are not perfect, but they are close to the ground truth, which shows that the pretrained checkpoint works reasonably well.

Now let's sample 10 random trajectories and compute the mean MSE to get a more detailed picture.

total_trajectories = len(dataset.trajectory_lengths)

print("Total trajectories:", total_trajectories)

sampled_trajectories = np.random.choice(total_trajectories, 10)
print("Sampled trajectories:", sampled_trajectories)

all_mses = []

for traj_id in sampled_trajectories:
    mse = calc_mse_for_single_trajectory(
        pre_trained_policy,
        dataset,
        traj_id=traj_id,
        modality_keys=["right_arm", "right_hand"],   # we will only evaluate the right arm and right hand
        steps=150,
        action_horizon=16,
        plot=False
    )
    print(f"Trajectory {traj_id} MSE: {mse:.4f}")
    
    all_mses.append(mse)

print("====================================")
print("Mean MSE:", np.mean(all_mses))
print("Std MSE:", np.std(all_mses))
Total trajectories: 5
Sampled trajectories: [2 1 0 1 3 2 3 0 4 3]
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 3.20396729700109
Trajectory 2 MSE: 3.2040
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 2.2953596460669803
Trajectory 1 MSE: 2.2954
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 4.289567449236214
Trajectory 0 MSE: 4.2896
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 2.8355197990968897
Trajectory 1 MSE: 2.8355
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 2.7388654553913465
Trajectory 3 MSE: 2.7389
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 3.007079749042383
Trajectory 2 MSE: 3.0071
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 2.374824835012358
Trajectory 3 MSE: 2.3748
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 3.4457704221466785
Trajectory 0 MSE: 3.4458
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 4.591438530501049
Trajectory 4 MSE: 4.5914
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 3.7617346859045115
Trajectory 3 MSE: 3.7617
====================================
Mean MSE: 3.2544127869399504
Std MSE: 0.7315471909013109

3.2 Fine-Tuning the Model

Now we will fine-tune the model on the dataset. Without going into the details of the fine-tuning process, we will use the gr00t_finetune.py script. You can run the following command to fine-tune the model.

python scripts/gr00t_finetune.py --dataset-path ./demo_data/robot_sim.PickNPlace --num-gpus 1 --max-steps 500 --output-dir /tmp/gr00t-1/finetuned-model --data-config gr1_arms_only

To get the full list of available arguments, run python scripts/gr00t_finetune.py --help.

The script saves the fine-tuned model in the /tmp/gr00t-1/finetuned-model directory. We will load the fine-tuned model from the checkpoint at step 500.
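If you are not sure which checkpoint step to load, a small helper like the sketch below (plain Python added here for illustration, not part of the tutorial) picks the newest checkpoint-* directory under the output dir:

import glob

# Collect checkpoint-* directories and sort them by step number
checkpoints = sorted(
    glob.glob("/tmp/gr00t-1/finetuned-model/checkpoint-*"),
    key=lambda p: int(p.rsplit("-", 1)[-1]),
)
finetuned_model_path = checkpoints[-1]  # e.g. .../checkpoint-500
print("Using checkpoint:", finetuned_model_path)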

3.2.1 Evaluating the Fine-Tuned Model

Now we can evaluate the fine-tuned model by running the policy on the dataset and checking how well it performs. We will use a utility function to evaluate the policy on the dataset, similar to the earlier 1_pretrained_model.ipynb tutorial.

from gr00t.utils.eval import calc_mse_for_single_trajectory
import warnings

finetuned_model_path = "/tmp/gr00t-1/finetuned-model/checkpoint-500"
finetuned_policy = Gr00tPolicy(
    model_path=finetuned_model_path,
    embodiment_tag="new_embodiment",
    modality_config=modality_config,
    modality_transform=modality_transform,
    device=device,
)

warnings.simplefilter("ignore", category=FutureWarning)

mse = calc_mse_for_single_trajectory(
    finetuned_policy,
    dataset,
    traj_id=0,
    modality_keys=["right_arm", "right_hand"],   # we will only evaluate the right arm and right hand
    steps=150,
    action_horizon=16,
    plot=True
)

print("MSE loss for trajectory 0:", mse)
Model not found or avail in the huggingface hub. Loading from local path: /tmp/gr00t-1/finetuned-model-1/checkpoint-500
Loading pretrained dual brain from /tmp/gr00t-1/finetuned-model-1/checkpoint-500
Tune backbone vision tower: True
Tune backbone LLM: False
Tune action head projector: True
Tune action head DiT: True
Model not found or avail in the huggingface hub. Loading from local path: /tmp/gr00t-1/finetuned-model-1/checkpoint-500
Total number of DiT parameters:  537803776
Tune action head projector: True
Tune action head diffusion model: True
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.07it/s]
Tune action head projector: True
Tune action head diffusion model: True
inferencing at step:  0
inferencing at step:  16
inferencing at step:  32
inferencing at step:  48
inferencing at step:  64
inferencing at step:  80
inferencing at step:  96
inferencing at step:  112
inferencing at step:  128
inferencing at step:  144
Unnormalized Action MSE across single traj: 0.005610098706850138

Yay! We fine-tuned the model and evaluated it on the dataset. The single-trajectory MSE dropped from about 3.68 with the pretrained checkpoint to about 0.0056 after fine-tuning, so the model has clearly learned the task and performs it much better than the pretrained model.
