RT1复现(四)
这部分工作是测试部分,language-table测试有5个任务,测试过程随机生成颜色形状块,然后生成相对应的指令,每个任务有对应的成功标准,比如多少m之内就算成功
Language-Table中定义了五个模拟评估任务系列(涵盖了696个独特的任务条件),每个任务系列都有一个手动定义的成功标准:
- block2block: 推一个块到另一个块。成功是源块和目标块之间的阈值距离。有56个唯一的任务条件(8个源块x7个目标块)。
- block2abs: 将一个块推到板上的绝对位置:左上角、中上、右上角、左中、中心、右中、左下角、中下、右下角。成功是块和目标位置之间的阈值距离。有72个唯一的任务条件(8个块x9个位置)。
- block2rel: 将一个块推到相对偏移的位置:左、右、上、下、左上、右上、左下、右下。成功是块和不可见目标偏移位置之间的阈值距离。有64个唯一的任务条件(8个块x8个偏移方向)。
- block2blockrel: 将一个块推到另一个块的相对偏移位置:左侧、右侧、顶部、底部、左上侧、右上侧、左下侧、右下侧。成功是源块和目标块的不可见目标偏移位置之间的阈值距离。有448个唯一的任务条件(8个源块x7个目标块x8个偏移方向)。
- separate: 将两个块分开。成功是两个块之间的阈值距离。有56个唯一的任务条件(8个源块x7个目标块)。
下载language-table项目
git clone https://github.com/google-research/language-table.git
然后需要修改eval/wrappers.py
# coding=utf-8
# Copyright 2023 The Language Table Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Environment wrappers."""
from typing import Any, Optional
from language_table.common import rt1_tokenizer
import numpy as np
import tensorflow as tf
from tf_agents.environments import wrappers
class UseTokenWrapper(wrappers.PyEnvironmentBaseWrapper):
  """Environment wrapper that adds USE instruction embeddings to the obs.

  The wrapped env exposes the episode instruction as a zero-padded UTF-8
  byte array under obs['instruction']; this wrapper decodes it once per
  episode, embeds it with the Universal Sentence Encoder tokenizer
  (rt1_tokenizer), and attaches the result under
  obs['instruction_tokenized_use'] on every step.
  """

  def __init__(self, env, context_length=77):
    """Wraps `env` so observations carry the tokenized instruction.

    Args:
      env: The py_environment to wrap.
      context_length: Kept for interface compatibility with the original
        CLIP-token wrapper; unused by the USE-based tokenizer.
    """
    super(UseTokenWrapper, self).__init__(env)
    self._context_length = context_length
    # Embedding of the current episode's instruction; computed at reset and
    # reused on every subsequent step (the instruction is fixed per episode).
    self._current_tokens = None

  def _reset(self):
    """Resets the env and embeds the new episode's instruction."""
    time_step = self._env.reset()
    self._current_tokens = self._tokenize(time_step.observation['instruction'])
    new_obs = time_step.observation
    new_obs['instruction_tokenized_use'] = self._current_tokens
    return time_step._replace(observation=new_obs)

  def _step(self, action):
    """Steps the env, re-attaching the cached instruction embedding."""
    time_step = self._env.step(action)
    new_obs = time_step.observation
    new_obs['instruction_tokenized_use'] = tf.convert_to_tensor(self._current_tokens)
    return time_step._replace(observation=new_obs)

  def _tokenize(self, instruction):
    """Decodes a zero-padded byte array to text and embeds it with USE."""
    bytes_list = instruction
    # Strip the zero padding the env uses for its fixed-size byte buffer.
    non_zero = bytes_list[np.where(bytes_list != 0)]
    if non_zero.shape[0] == 0:
      decoded = ''
    else:
      bytes_list = bytes(non_zero.tolist())
      decoded = bytes_list.decode('utf-8')
    tokens = rt1_tokenizer.tokenize_text(decoded)
    return tokens
class CentralCropImageWrapper(wrappers.PyEnvironmentBaseWrapper):
  """Environment wrapper that crops image observations."""

  def __init__(self,
               env,
               target_height,
               target_width,
               random_crop_factor=None):
    """Centrally crops an image from a dict observation.

    Args:
      env: The py_environment to wrap.
      target_height: Output image height after crop and resize.
      target_width: Output image width after crop and resize.
      random_crop_factor: Fraction of each spatial dimension retained by the
        center crop; mirrors the training-time random-crop augmentation.
    """
    super(CentralCropImageWrapper, self).__init__(env)
    self._target_height = target_height
    self._target_width = target_width
    self._random_crop_factor = random_crop_factor

  def _reset(self):
    """Resets the env and attaches the cropped image to the observation."""
    step = self._env.reset()
    return step._replace(observation=self._crop_observation(step.observation))

  def _step(self, action):
    """Steps the env and attaches the cropped image to the observation."""
    step = self._env.step(action)
    return step._replace(observation=self._crop_observation(step.observation))

  def _crop_observation(self, obs):
    """Adds a cropped, resized float32 copy of obs['rgb'] as 'rgb_sequence'.

    Note: mutates the incoming observation dict in place and returns it.
    """
    frame = tf.image.convert_image_dtype(obs['rgb'], dtype=tf.float32)
    # Deterministic center crop matching the average training-time crop.
    frame = crop_test_image(frame, self._random_crop_factor)
    frame = resize_images(frame, self._target_height, self._target_width)
    obs['rgb_sequence'] = tf.convert_to_tensor(frame)
    return obs
def crop_test_image(images, random_crop_factor):
  """Applies the center crop equivalent to the training-time random crop.

  Args:
    images: A single image tensor (H, W, C) or a batch (N, H, W, C).
    random_crop_factor: Fraction of each spatial dimension to keep.

  Returns:
    The center-cropped image(s), same rank as the input.
  """

  def _center_crop(im):
    # Keep `random_crop_factor` of height and width, centered in the frame.
    raw_shape = tf.shape(im)
    full_h = tf.cast(raw_shape[0], tf.float32)
    full_w = tf.cast(raw_shape[1], tf.float32)
    keep_h = full_h * random_crop_factor
    keep_w = full_w * random_crop_factor
    return tf.image.crop_to_bounding_box(
        im,
        offset_height=tf.cast((full_h - keep_h) // 2, tf.int32),
        offset_width=tf.cast((full_w - keep_w) // 2, tf.int32),
        target_height=tf.cast(keep_h, tf.int32),
        target_width=tf.cast(keep_w, tf.int32))

  # Rank 3 means a single image; otherwise map over the leading batch axis.
  if len(images.shape) == 3:
    return _center_crop(images)
  return tf.map_fn(_center_crop, images)
def resize_images(images, target_height=None, target_width=None):
  """Resizes images to (target_height, target_width).

  Args:
    images: Image tensor (H, W, C) or batch (N, H, W, C).
    target_height: Required positive output height.
    target_width: Required positive output width.

  Returns:
    The resized image tensor.

  Raises:
    ValueError: If target_height or target_width is missing or falsy.
      (The original used `assert`, which is silently stripped under
      `python -O`; an explicit raise keeps the validation unconditional.)
  """
  if not target_height or not target_width:
    raise ValueError(
        "resize_images requires positive target_height and target_width; "
        "got target_height=%r, target_width=%r" % (target_height, target_width))
  # Resize to target height and width.
  return tf.image.resize(images, [target_height, target_width])
rt1_tokenizer.py
import tensorflow_hub as hub
# Lazily-loaded Universal Sentence Encoder SavedModel. The original code
# called hub.load() inside tokenize_text, reloading the model from disk on
# every call; caching it here makes repeated tokenization calls cheap.
_EMBED = None


def tokenize_text(text):
  """Embeds the input text with the Universal Sentence Encoder.

  Args:
    text: A single instruction string.

  Returns:
    The embedding tensor produced by the USE model for `[text]`.
  """
  global _EMBED
  if _EMBED is None:
    _EMBED = hub.load("/mnt/ve_share2/zy/Universal_Sentence_Encoder")
  return _EMBED([text])
eval.py (参考eval/main.py进行修改)
import imageio
import os
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
parent_dir = os.path.dirname(current_dir)  # its parent directory
# parent_dir = os.path.dirname(parent_dir1)
sys.path.append(parent_dir)  # make distribute_train importable from the parent dir
from distribute_train import get_args,create_model,create_train_dataset
from tf_agents.specs import tensor_spec
import tensorflow as tf
import collections
from collections.abc import Sequence
import os
from absl import app
from absl import flags
from absl import logging
import jax
import numpy as np
# from language_table.common import rt1_tokenizer
from language_table.environments import blocks
from language_table.environments import language_table
from language_table.environments.oracles import push_oracle_rrt_slowdown
from language_table.environments.rewards import block2absolutelocation
from language_table.environments.rewards import block2block
from language_table.environments.rewards import block2block_relative_location
from language_table.environments.rewards import block2relativelocation
from language_table.environments.rewards import separate_blocks
from language_table.eval import wrappers as env_wrappers
from language_table.train import policy as jax_policy
from ml_collections import config_flags
import tensorflow as tf
import tensorflow_hub as hub
from tf_agents.environments import gym_wrapper
from tf_agents.environments import wrappers as tfa_wrappers
# Command-line flags: the ml_collections training config file and the output
# working directory. Defaults point at local development paths.
_CONFIG = config_flags.DEFINE_config_file(
    "config","/mnt/ve_share2/zy/robotics_transformer/language_table/train/configs/language_table_sim_local.py", "Training configuration.", lock_config=True)
_WORKDIR = flags.DEFINE_string("workdir","/mnt/ve_share2/zy/eval", "working dir")
def get_ckpt_model():
  """Builds the RT-1 network and restores the latest training checkpoint.

  Returns:
    A tuple (model, network_state): the restored network and an initial
    network state sampled from its state spec with batch size 1.
  """
  time_sequence_length = 6  # Constant from the paper: each prediction uses 6 images.
  args = get_args()
  # NOTE(review): `gpus` is queried but unused; device placement below is
  # hard-coded to '/gpu:1' — confirm this matches the eval machine.
  gpus = tf.config.experimental.list_physical_devices('GPU')
  with tf.device('/gpu:1'):
    network = create_model(args)
    network_state = tensor_spec.sample_spec_nest(
        network.state_spec, outer_dims=[1])
    # `step` is a dummy counter required by the Checkpoint signature.
    ckpt = tf.train.Checkpoint(step=tf.Variable(9),model=network)
    # if tf.train.latest_checkpoint(args.loaded_checkpoints_dir):
    # expect_partial(): only the model variables are restored; silence
    # warnings about the unmatched training-only objects in the checkpoint.
    ckpt.restore(tf.train.latest_checkpoint(args.loaded_checkpoints_dir)).expect_partial()
    print("从 %s 恢复模型" % (tf.train.latest_checkpoint(args.loaded_checkpoints_dir)))
  return ckpt.model,network_state
def evaluate_checkpoint(workdir, config):
  """Evaluates the given checkpoint and writes results to workdir.

  Runs every reward family for a fixed number of episodes, counts successes
  per family, and writes a video of each rollout under workdir/videos.

  Args:
    workdir: Output directory for rollout videos.
    config: Evaluation/training config (crop, resize and history settings).
  """
  video_dir = os.path.join(workdir, "videos")
  if not tf.io.gfile.exists(video_dir):
    tf.io.gfile.makedirs(video_dir)
  # One reward factory per simulated evaluation task family.
  rewards = {
      "blocktoblock":
          block2block.BlockToBlockReward,
      "blocktoabsolutelocation":
          block2absolutelocation.BlockToAbsoluteLocationReward,
      "blocktoblockrelativelocation":
          block2block_relative_location.BlockToBlockRelativeLocationReward,
      "blocktorelativelocation":
          block2relativelocation.BlockToRelativeLocationReward,
      "separate":
          separate_blocks.SeparateBlocksReward,
  }
  num_evals_per_reward = 50
  max_episode_steps = 200
  policy = None
  model,network_state = get_ckpt_model()
  # Maps reward family name -> number of successful episodes.
  results = collections.defaultdict(lambda: 0)
  for reward_name, reward_factory in rewards.items():
    env = language_table.LanguageTable(
        block_mode=blocks.LanguageTableBlockVariants.BLOCK_8,
        reward_factory=reward_factory,
        seed=0)
    env = gym_wrapper.GymWrapper(env)
    # Wrapper order: embed instructions, crop/resize images, then stack
    # a history of observations for the sequence model.
    env = env_wrappers.UseTokenWrapper(env)
    env = env_wrappers.CentralCropImageWrapper(
        env,
        target_width=config.data_target_width,
        target_height=config.data_target_height,
        random_crop_factor=config.random_crop_factor)
    env = tfa_wrappers.HistoryWrapper(
        env, history_length=config.sequence_length, tile_first_step_obs=True)
    if policy is None:
      # Built once and reused; specs are identical across reward families.
      policy = jax_policy.BCJaxPyPolicy(
          env.time_step_spec(),
          env.action_spec(),
          model=model,
          network_state=network_state,
          rng=jax.random.PRNGKey(0))
    for ep_num in range(num_evals_per_reward):
      # Reset env. Choose new init if oracle cannot find valid motion plan.
      # Get an oracle. We use this at the moment to decide whether an
      # environment initialization is valid. If oracle can motion plan,
      # init is valid.
      oracle_policy = push_oracle_rrt_slowdown.ObstacleOrientedPushOracleBoard2dRRT(
          env, use_ee_planner=True)
      plan_success = False
      while not plan_success:
        ts = env.reset()
        raw_state = env.compute_state()
        plan_success = oracle_policy.get_plan(raw_state)
        if not plan_success:
          logging.info(
              "Resetting environment because the "
              "initialization was invalid (could not find motion plan).")
      frames = [env.render()]
      episode_steps = 0
      while not ts.is_last():
        policy_step = policy.action(ts, ())
        ts = env.step(policy_step.action)
        frames.append(env.render())
        episode_steps += 1
        # Hard cap on episode length in case the env never signals done.
        if episode_steps > max_episode_steps:
          break
      success_str = ""
      if env.succeeded:
        results[reward_name] += 1
        logging.info("Episode %d: success.", ep_num)
        success_str = "success"
      else:
        logging.info("Episode %d: failure.", ep_num)
        success_str = "failure"
      # Write out video of rollout.
      video_path = os.path.join(workdir, "videos/",
                                f"{reward_name}_{ep_num}_{success_str}.mp4")
      imageio.mimsave(video_path, frames, fps=10)
  print(results)
def main(argv):
  """absl entry point: runs evaluation with the flag-provided settings."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  evaluate_checkpoint(
      workdir=_WORKDIR.value,
      config=_CONFIG.value,
  )


if __name__ == "__main__":
  app.run(main)