RT1 Reproduction (Part 4)

This part covers evaluation. Language-Table evaluation uses five task families: during each evaluation episode, blocks of random colors and shapes are generated together with a matching instruction, and each task has its own success criterion, e.g. finishing within some threshold distance (in meters) of the target.

Language-Table defines five simulated evaluation task families (covering 696 unique task conditions), each with a manually defined success criterion; a minimal sketch of such a threshold check follows the list:

  • block2block: Push one block to another block. Success is a threshold distance between the source block and the target block. There are 56 unique task conditions (8 source blocks x 7 target blocks).

  • block2abs: Push a block to an absolute location on the board: top left, top center, top right, center left, center, center right, bottom left, bottom center, bottom right. Success is a threshold distance between the block and the target location. There are 72 unique task conditions (8 blocks x 9 locations).

  • block2rel: Push a block to a location at a relative offset: left, right, up, down, up left, up right, down left, down right. Success is a threshold distance between the block and the invisible target offset location. There are 64 unique task conditions (8 blocks x 8 offset directions).

  • block2blockrel: Push a block to a relative offset of another block: left of, right of, top of, bottom of, top left of, top right of, bottom left of, bottom right of. Success is a threshold distance between the source block and the invisible offset location relative to the target block. There are 448 unique task conditions (8 source blocks x 7 target blocks x 8 offset directions).

  • separate: Separate two blocks. Success is a threshold distance between the two blocks. There are 56 unique task conditions (8 source blocks x 7 target blocks).
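
As a rough sketch, each criterion boils down to a distance check like the following (hypothetical function and threshold value, not the repo's actual reward code, which lives under language_table/environments/rewards):

import numpy as np

def is_success(block_xy, target_xy, threshold_m=0.05):
  # Hypothetical check: block center within `threshold_m` meters of the target.
  return float(np.linalg.norm(np.asarray(block_xy) - np.asarray(target_xy))) < threshold_m

print(is_success([0.10, 0.20], [0.10, 0.23]))  # 0.03 m away -> True under a 5 cm threshold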

Download the language-table project:

git clone https://github.com/google-research/language-table.git

Then eval/wrappers.py needs to be modified; the change replaces the original CLIP tokenizer with a Universal Sentence Encoder embedding, provided by the rt1_tokenizer module shown further below:

# coding=utf-8
# Copyright 2023 The Language Table Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Environment wrappers."""

from typing import Any, Optional

from language_table.common import rt1_tokenizer
import numpy as np
import tensorflow as tf
from tf_agents.environments import wrappers


class UseTokenWrapper(wrappers.PyEnvironmentBaseWrapper):
  """Environment wrapper that adds Universal Sentence Encoder embeddings to the obs."""

  def __init__(self, env, context_length=77):
    """Embeds the instruction from a dict observation."""
    super(UseTokenWrapper, self).__init__(env)
    # 77 is the CLIP context length; kept for compatibility but unused here.
    self._context_length = context_length
    # The original CLIP tokenizer is replaced by rt1_tokenizer (USE-based):
    # vocab_lookup = clip_tokenizer.create_vocab()
    # self._tokenizer = clip_tokenizer.ClipTokenizer(vocab_lookup)
    self._current_tokens = None

  def _reset(self):
    time_step = self._env.reset()
    self._current_tokens = self._tokenize(time_step.observation['instruction'])
    new_obs = time_step.observation
    new_obs['instruction_tokenized_use'] = self._current_tokens
    return time_step._replace(observation=new_obs)

  def _step(self, action):
    time_step = self._env.step(action)
    new_obs = time_step.observation
    new_obs['instruction_tokenized_use'] = tf.convert_to_tensor(self._current_tokens)
    return time_step._replace(observation=new_obs)

  def _tokenize(self, instruction):
    # The instruction arrives as a zero-padded array of UTF-8 byte values;
    # strip the padding and decode it back into a string.
    bytes_list = instruction
    non_zero = bytes_list[np.where(bytes_list != 0)]
    if non_zero.shape[0] == 0:
      decoded = ''
    else:
      bytes_list = bytes(non_zero.tolist())
      decoded = bytes_list.decode('utf-8')

    tokens = rt1_tokenizer.tokenize_text(decoded)
    return tokens


class CentralCropImageWrapper(wrappers.PyEnvironmentBaseWrapper):
  """Environment wrapper that crops image observations."""

  def __init__(self,
               env,
               target_height,
               target_width,
               random_crop_factor=None):
    """Centrally crops an image from a dict observation."""
    super(CentralCropImageWrapper, self).__init__(env)
    self._target_height = target_height
    self._target_width = target_width
    self._random_crop_factor = random_crop_factor

  def _reset(self):
    time_step = self._env.reset()
    new_obs = self._crop_observation(time_step.observation)
    return time_step._replace(observation=new_obs)

  def _step(self, action):
    time_step = self._env.step(action)
    new_obs = self._crop_observation(time_step.observation)
    return time_step._replace(observation=new_obs)

  def _crop_observation(self, obs):
    new_obs = obs
    image = obs['rgb']
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Apply average crop augmentation.
    image = crop_test_image(image, self._random_crop_factor)
    image = resize_images(image, self._target_height, self._target_width)
    new_obs['rgb_sequence'] = tf.convert_to_tensor(image)
    return new_obs


def crop_test_image(images, random_crop_factor):
  """Get the average crop applied during crop training augmentation."""

  def take_center_crop_consistent_with_random(im):
    im_raw_size = tf.shape(im)
    raw_height = tf.cast(im_raw_size[0], tf.float32)
    raw_width = tf.cast(im_raw_size[1], tf.float32)
    scaled_height = raw_height * random_crop_factor
    scaled_width = raw_width * random_crop_factor
    offset_height = tf.cast((raw_height - scaled_height) // 2, tf.int32)
    offset_width = tf.cast((raw_width - scaled_width) // 2, tf.int32)
    target_height = tf.cast(scaled_height, tf.int32)
    target_width = tf.cast(scaled_width, tf.int32)
    im = tf.image.crop_to_bounding_box(
        im,
        offset_height=offset_height,
        offset_width=offset_width,
        target_height=target_height,
        target_width=target_width)
    return im

  if len(images.shape) == 3:
    return take_center_crop_consistent_with_random(images)
  images = tf.map_fn(take_center_crop_consistent_with_random, images)
  return images


def resize_images(images, target_height=None, target_width=None):
  """Resizes images to target_height, target_width."""
  assert target_height
  assert target_width

  # Resize to target height and width.
  def _resize(im):
    return tf.image.resize(im, [target_height, target_width])

  images = _resize(images)

  return images
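
A quick sanity check of the crop-then-resize path above (the input shape, crop factor, and target sizes here are made-up; use the values from your config):

import tensorflow as tf

image = tf.zeros([360, 640, 3], dtype=tf.float32)
cropped = crop_test_image(image, random_crop_factor=0.95)
print(cropped.shape)  # (342, 608, 3): the center 95% along each spatial dimension
print(resize_images(cropped, target_height=256, target_width=320).shape)  # (256, 320, 3)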

rt1_tokenizer.py (placed under language_table/common/, matching the import at the top of wrappers.py):

import tensorflow_hub as hub

# Load the Universal Sentence Encoder once at import time rather than on
# every call; tokenize_text runs at each environment reset.
_embed = hub.load("/mnt/ve_share2/zy/Universal_Sentence_Encoder")


def tokenize_text(text):
  """Embeds the input text with the Universal Sentence Encoder."""
  tokens = _embed([text])
  return tokens

# text ='push the blue triangle closer to yellow heart'
# print(tokenize_text(text))
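
Uncommenting the check above is a quick way to confirm the embedding shape; assuming the local model is the standard Universal Sentence Encoder, the output is a 512-dimensional sentence embedding:

tokens = tokenize_text('push the blue triangle closer to yellow heart')
print(tokens.shape)  # (1, 512) for the standard Universal Sentence Encoder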

eval.py (adapted from eval/main.py):

import imageio
import os
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
parent_dir = os.path.dirname(current_dir)  # its parent directory
sys.path.append(parent_dir)
from distribute_train import get_args, create_model, create_train_dataset
from tf_agents.specs import tensor_spec
import tensorflow as tf
import collections
from collections.abc import Sequence
from absl import app
from absl import flags
from absl import logging
import jax
import numpy as np

# from language_table.common import rt1_tokenizer
from language_table.environments import blocks
from language_table.environments import language_table
from language_table.environments.oracles import push_oracle_rrt_slowdown
from language_table.environments.rewards import block2absolutelocation
from language_table.environments.rewards import block2block
from language_table.environments.rewards import block2block_relative_location
from language_table.environments.rewards import block2relativelocation
from language_table.environments.rewards import separate_blocks
from language_table.eval import wrappers as env_wrappers
from language_table.train import policy as jax_policy
from ml_collections import config_flags

import tensorflow_hub as hub
from tf_agents.environments import gym_wrapper
from tf_agents.environments import wrappers as tfa_wrappers

_CONFIG = config_flags.DEFINE_config_file(
    "config",
    "/mnt/ve_share2/zy/robotics_transformer/language_table/train/configs/language_table_sim_local.py",
    "Training configuration.",
    lock_config=True)
_WORKDIR = flags.DEFINE_string("workdir", "/mnt/ve_share2/zy/eval", "working dir")

def get_ckpt_model():
    time_sequence_length = 6  # constant: per the RT-1 paper, each prediction consumes 6 frames
    args = get_args()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    with tf.device('/gpu:1'):
        network = create_model(args)
        network_state = tensor_spec.sample_spec_nest(
            network.state_spec, outer_dims=[1])
        ckpt = tf.train.Checkpoint(step=tf.Variable(9), model=network)
        # if tf.train.latest_checkpoint(args.loaded_checkpoints_dir):
        ckpt.restore(tf.train.latest_checkpoint(args.loaded_checkpoints_dir)).expect_partial()
        print("Restored model from %s" % (tf.train.latest_checkpoint(args.loaded_checkpoints_dir)))
        return ckpt.model, network_state

def evaluate_checkpoint(workdir, config):
  """Evaluates the given checkpoint and writes results to workdir."""
  video_dir = os.path.join(workdir, "videos")
  if not tf.io.gfile.exists(video_dir):
    tf.io.gfile.makedirs(video_dir)
  rewards = {
      "blocktoblock":
          block2block.BlockToBlockReward,
      "blocktoabsolutelocation":
          block2absolutelocation.BlockToAbsoluteLocationReward,
      "blocktoblockrelativelocation":
          block2block_relative_location.BlockToBlockRelativeLocationReward,
      "blocktorelativelocation":
          block2relativelocation.BlockToRelativeLocationReward,
      "separate":
          separate_blocks.SeparateBlocksReward,
  }

  num_evals_per_reward = 50
  max_episode_steps = 200

  policy = None
  model, network_state = get_ckpt_model()

  results = collections.defaultdict(lambda: 0)
  for reward_name, reward_factory in rewards.items():
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
        reward_factory=reward_factory,
        seed=0)
    env = gym_wrapper.GymWrapper(env)
    # Wrapper chain: embed the instruction with USE, center-crop and resize
    # the RGB frames, then stack the last `sequence_length` observations.
    env = env_wrappers.UseTokenWrapper(env)
    env = env_wrappers.CentralCropImageWrapper(
        env,
        target_width=config.data_target_width,
        target_height=config.data_target_height,
        random_crop_factor=config.random_crop_factor)
    env = tfa_wrappers.HistoryWrapper(
        env, history_length=config.sequence_length, tile_first_step_obs=True)

    if policy is None:
      policy = jax_policy.BCJaxPyPolicy(
          env.time_step_spec(),
          env.action_spec(),
          model=model,
          network_state=network_state,
          rng=jax.random.PRNGKey(0))

    for ep_num in range(num_evals_per_reward):
      # Reset env. Choose new init if oracle cannot find valid motion plan.
      # Get an oracle. We use this at the moment to decide whether an
      # environment initialization is valid. If oracle can motion plan,
      # init is valid.
      oracle_policy = push_oracle_rrt_slowdown.ObstacleOrientedPushOracleBoard2dRRT(
          env, use_ee_planner=True)
      plan_success = False
      while not plan_success:
        ts = env.reset()
        raw_state = env.compute_state()
        plan_success = oracle_policy.get_plan(raw_state)
        if not plan_success:
          logging.info(
              "Resetting environment because the "
              "initialization was invalid (could not find motion plan).")

      frames = [env.render()]

      episode_steps = 0
      while not ts.is_last():
        policy_step = policy.action(ts, ())
        ts = env.step(policy_step.action)
        frames.append(env.render())
        episode_steps += 1

        if episode_steps > max_episode_steps:
          break

      success_str = ""
      if env.succeeded:
        results[reward_name] += 1
        logging.info("Episode %d: success.", ep_num)
        success_str = "success"
      else:
        logging.info("Episode %d: failure.", ep_num)
        success_str = "failure"

      # Write out video of rollout.
      video_path = os.path.join(video_dir,
                                f"{reward_name}_{ep_num}_{success_str}.mp4")

      imageio.mimsave(video_path, frames, fps=10)

    print(results)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  evaluate_checkpoint(
      workdir=_WORKDIR.value,
      config=_CONFIG.value,
  )


if __name__ == "__main__":
  app.run(main)
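
After a run, results maps each task family name to its raw success count over the 50 episodes (num_evals_per_reward above). A hypothetical post-processing snippet (the count here is made-up, just to show the arithmetic) to turn the printed counter into success rates:

def success_rates(results, num_evals_per_reward=50):
    # Fraction of successful episodes per task family.
    return {name: count / num_evals_per_reward for name, count in results.items()}

print(success_rates({'blocktoblock': 30}))  # {'blocktoblock': 0.6}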



