RT-1 Reproduction (Part 3)

This part of the work extracts features for the language component, using the Universal Sentence Encoder model from the RT-1 paper to preprocess the data.
First, convert the language_table_sim-train.tfrecord-0000 files into episode_0.npy format; splitting out validation and test sets is optional.
rlds_np_save.py

import os

import numpy as np
import rlds
import tensorflow_datasets as tfds
from tqdm import tqdm

# Directory (or directories) containing the downloaded language_table_sim dataset; fill in your local path.
dataset_dirs = ['/path/to/language_table_sim']

dataset_episode_num = 20200
builder = tfds.builder_from_directories(dataset_dirs)

ds = builder.as_dataset(
    split=f'train[{0}:{dataset_episode_num}]',
    decoders={"steps": {"observation": {"rgb": tfds.decode.SkipDecoding()}}},
    shuffle_files=False
)


# Save one raw episode as a .npy file: flatten each step's fields
# (plus its observation sub-fields) into a single dict per step.
def create_episode(path, raw_episode):
    episode = []
    for step in raw_episode[rlds.STEPS]:
        observation = step[rlds.OBSERVATION]
        observation_keys = observation.keys()
        step_keys = list(step.keys())
        step_keys.remove(rlds.OBSERVATION)
        step_dict = {}
        # Copy the top-level step fields (action, reward, is_first, ...).
        for k in step_keys:
            step_dict[k] = step[k].numpy()
        # Copy the observation fields (rgb, instruction, ...).
        for k in observation_keys:
            step_dict[k] = observation[k].numpy()
        episode.append(step_dict)
    np.save(path, episode)

# Number of episodes in each split.

NUM_TRAIN = 20000
NUM_VAL = 100
NUM_TEST = 100

# Generate the corresponding splits.

print("Generating train examples...")
os.makedirs('/mnt/ve_share2/zy/np_dataset/data/train', exist_ok=True)
cnt = 0
for element in tqdm(ds.take(NUM_TRAIN)):
    create_episode(f'/mnt/ve_share2/zy/np_dataset/data/train/episode_{cnt}.npy', element)
    cnt = cnt + 1

print("Generating val examples...")
os.makedirs('/mnt/ve_share2/zy/np_dataset/data/val', exist_ok=True)
cnt = 0
for element in tqdm(ds.skip(NUM_TRAIN).take(NUM_VAL)):
    create_episode(f'/mnt/ve_share2/zy/np_dataset/data/val/episode_{cnt}.npy', element)
    cnt = cnt + 1

print("Generating test examples...")
os.makedirs('/mnt/ve_share2/zy/np_dataset/data/test', exist_ok=True)
cnt = 0
for element in tqdm(ds.skip(NUM_TRAIN + NUM_VAL).take(NUM_TEST)):
    create_episode(f'/mnt/ve_share2/zy/np_dataset/data/test/episode_{cnt}.npy', element)
    cnt = cnt + 1


This produces the following files:
data
├── train
│ ├── episode_0.npy
│ ├── episode_1.npy
│ ├── …
├── val
│ ├── episode_0.npy
│ ├── episode_1.npy
│ ├── …
├── test
│ ├── episode_0.npy
│ ├── episode_1.npy
│ ├── …
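
As a quick sanity check (a minimal sketch, assuming the train output path from the script above), one of the saved episodes can be loaded back and its per-step structure inspected:

import numpy as np

# Load one converted episode; np.save stored a list of per-step dicts, so allow_pickle is required.
episode = np.load('/mnt/ve_share2/zy/np_dataset/data/train/episode_0.npy', allow_pickle=True)
print(len(episode))               # number of steps in the episode
print(sorted(episode[0].keys()))  # action, effector_translation, instruction, is_first, rgb, ...
print(type(episode[0]['rgb']))    # encoded image bytes, since rgb decoding was skipped above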

Then move data into the same directory as language_table_use_dataset_builder.py:
language_table_use
├── data
├── language_table_use_dataset_builder.py
├── …

Run the following commands:

# create the environment
conda env create -f environment_ubuntu.yml
# activate the conda environment
conda activate rlds_env
# build the dataset
tfds build --data_dir '<path where the final dataset will be stored>'

language_table_use_dataset_builder.py

from typing import Iterator, Tuple, Any

import glob
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub


class LanguageTableUse(tfds.core.GeneratorBasedBuilder):
    """DatasetBuilder for example dataset."""

    VERSION = tfds.core.Version('1.0.0')
    RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._embed = hub.load("/mnt/ve_share2/zy/Universal_Sentence_Encoder")

    def _info(self) -> tfds.core.DatasetInfo:
        """Dataset metadata (homepage, citation,...)."""
        return self.dataset_info_from_configs(
            features=tfds.features.FeaturesDict({
                'steps': tfds.features.Dataset({
                    'observation': tfds.features.FeaturesDict({
                        'rgb': tfds.features.Image(
                            shape=(360, 640, 3),
                            dtype=np.uint8,
                            doc='RGB observation.',
                        ),
                        'effector_target_translation': tfds.features.Tensor(
                            shape=(2,),
                            dtype=np.float32,
                            doc='Target (x, y) position of the robot effector in the 2-D plane.',
                        ),
                        'effector_translation': tfds.features.Tensor(
                            shape=(2,),
                            dtype=np.float32,
                            doc='Current (x, y) position of the robot effector in the 2-D plane.',
                        ),
                        'instruction': tfds.features.Tensor(
                            shape=(512,),
                            dtype=np.float32,
                            doc='Universal Sentence Encoder embedding of the instruction.',
                        ),
                    }),
                    'action': tfds.features.Tensor(
                        shape=(2,),
                        dtype=np.float32,
                        doc='Robot action',
                    ),
                    'reward': tfds.features.Scalar(
                        dtype=np.float32,
                        doc='Reward if provided, 1 on final step for demos.'
                    ),
                    'is_first': tfds.features.Scalar(
                        dtype=np.bool_,
                        doc='True on first step of the episode.'
                    ),
                    'is_last': tfds.features.Scalar(
                        dtype=np.bool_,
                        doc='True on last step of the episode.'
                    ),
                    'is_terminal': tfds.features.Scalar(
                        dtype=np.bool_,
                        doc='True on last step of the episode if it is a terminal step, True for demos.'
                    ),
                }),
                'episode_metadata': tfds.features.FeaturesDict({
                    'file_path': tfds.features.Text(
                        doc='Path to the original data file.'
                    ),
                }),
            }))

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Define data splits."""
        return {
            'train': self._generate_examples(path='data/train/episode_*.npy'),
            'val': self._generate_examples(path='data/val/episode_*.npy'),
            'test': self._generate_examples(path='data/test/episode_*.npy'),
        }

    def _generate_examples(self, path) -> Iterator[Tuple[str, Any]]:
        """Generator of examples for each split."""

        def _parse_example(episode_path):
            # load the raw episode: a list of per-step dicts saved by rlds_np_save.py
            with open(episode_path, 'rb') as file:
                data = np.load(file, allow_pickle=True)

            def decode_inst(inst):
                # the raw instruction is a zero-padded byte array; strip the padding and decode to UTF-8
                return bytes(inst[np.where(inst != 0)].tolist()).decode("utf-8")

            # assemble the episode step by step
            episode = []
            for i, step in enumerate(data):
                # compute the Universal Sentence Encoder embedding of the instruction
                language_embedding = self._embed([decode_inst(np.array(step['instruction']))])[0].numpy()

                episode.append({
                    'observation': {
                        'rgb': step['rgb'],
                        'effector_target_translation': step['effector_target_translation'],
                        'effector_translation': step['effector_translation'],
                        'instruction': language_embedding,
                    },
                    'action': step['action'],
                    'reward': step['reward'],
                    'is_first': step['is_first'],
                    'is_last': step['is_last'],
                    'is_terminal': step['is_terminal'],
                })

            # create output data sample
            sample = {
                'steps': episode,
                'episode_metadata': {
                    'file_path': episode_path
                }
            }

            # if you want to skip an example for whatever reason, simply return None
            return episode_path, sample

        # create list of all examples
        episode_paths = glob.glob(path)

        # for smallish datasets, use single-thread parsing
        for sample in episode_paths:
            yield _parse_example(sample)

        # for large datasets use beam to parallelize data parsing (this will have initialization overhead)
        # beam = tfds.core.lazy_imports.apache_beam
        # return (
        #         beam.Create(episode_paths)
        #         | beam.Map(_parse_example)
        # )
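
The builder loads the Universal Sentence Encoder from a local directory (/mnt/ve_share2/zy/Universal_Sentence_Encoder above). Below is a minimal sketch for fetching the model from TF Hub and caching it locally; the TF Hub URL is the public universal-sentence-encoder/4 model, and the cache path is only an example:

import os
import tensorflow_hub as hub

# Cache TF Hub downloads locally so the model only needs to be fetched once
# (the cache path is just an example; adjust it to your setup).
os.environ["TFHUB_CACHE_DIR"] = "/mnt/ve_share2/zy/tfhub_cache"

# Download the Universal Sentence Encoder and verify it returns 512-dim embeddings,
# matching the (512,) 'instruction' feature declared in the builder.
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
print(embed(["push the blue cube to the left"]).shape)  # (1, 512)

# The extracted module directory inside TFHUB_CACHE_DIR (or a copy of it, e.g.
# /mnt/ve_share2/zy/Universal_Sentence_Encoder) can then be passed to hub.load()
# directly, which is what _embed in the builder above does.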


environment_ubuntu.yml

name: rlds_env
channels:
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - ca-certificates=2023.7.22=hbcca054_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - libffi=3.3=h58526e2_2
  - libgcc-ng=13.1.0=he5830b7_0
  - libgomp=13.1.0=he5830b7_0
  - libsqlite=3.42.0=h2797004_0
  - libstdcxx-ng=13.1.0=hfd8a6a1_0
  - libzlib=1.2.13=hd590300_5
  - ncurses=6.4=hcb278e6_0
  - openssl=1.1.1u=hd590300_0
  - pip=23.2.1=pyhd8ed1ab_0
  - python=3.9.0=hffdb5ce_5_cpython
  - readline=8.2=h8228510_1
  - setuptools=68.0.0=pyhd8ed1ab_0
  - sqlite=3.42.0=h2c6b66d_0
  - tk=8.6.12=h27826a3_0
  - tzdata=2023c=h71feb2d_0
  - wheel=0.41.0=pyhd8ed1ab_0
  - xz=5.2.6=h166bdaf_0
  - zlib=1.2.13=hd590300_5
  - pip:
      - absl-py==1.4.0
      - anyio==3.7.1
      - apache-beam==2.49.0
      - appdirs==1.4.4
      - array-record==0.4.0
      - astunparse==1.6.3
      - cachetools==5.3.1
      - certifi==2023.7.22
      - charset-normalizer==3.2.0
      - click==8.1.6
      - cloudpickle==2.2.1
      - contourpy==1.1.0
      - crcmod==1.7
      - cycler==0.11.0
      - dill==0.3.1.1
      - dm-tree==0.1.8
      - dnspython==2.4.0
      - docker-pycreds==0.4.0
      - docopt==0.6.2
      - etils==1.3.0
      - exceptiongroup==1.1.2
      - fastavro==1.8.2
      - fasteners==0.18
      - flatbuffers==23.5.26
      - fonttools==4.41.1
      - gast==0.4.0
      - gitdb==4.0.10
      - gitpython==3.1.32
      - google-auth==2.22.0
      - google-auth-oauthlib==1.0.0
      - google-pasta==0.2.0
      - googleapis-common-protos==1.59.1
      - grpcio==1.56.2
      - h11==0.14.0
      - h5py==3.9.0
      - hdfs==2.7.0
      - httpcore==0.17.3
      - httplib2==0.22.0
      - idna==3.4
      - importlib-metadata==6.8.0
      - importlib-resources==6.0.0
      - keras==2.13.1
      - kiwisolver==1.4.4
      - libclang==16.0.6
      - markdown==3.4.3
      - markupsafe==2.1.3
      - matplotlib==3.7.2
      - numpy==1.24.3
      - oauthlib==3.2.2
      - objsize==0.6.1
      - opt-einsum==3.3.0
      - orjson==3.9.2
      - packaging==23.1
      - pathtools==0.1.2
      - pillow==10.0.0
      - plotly==5.15.0
      - promise==2.3
      - proto-plus==1.22.3
      - protobuf==4.23.4
      - psutil==5.9.5
      - pyarrow==11.0.0
      - pyasn1==0.5.0
      - pyasn1-modules==0.3.0
      - pydot==1.4.2
      - pymongo==4.4.1
      - pyparsing==3.0.9
      - python-dateutil==2.8.2
      - pytz==2023.3
      - pyyaml==6.0.1
      - regex==2023.6.3
      - requests==2.31.0
      - requests-oauthlib==1.3.1
      - rsa==4.9
      - sentry-sdk==1.28.1
      - setproctitle==1.3.2
      - six==1.16.0
      - smmap==5.0.0
      - sniffio==1.3.0
      - tenacity==8.2.2
      - tensorboard==2.13.0
      - tensorboard-data-server==0.7.1
      - tensorflow==2.13.0
      - tensorflow-datasets==4.9.2
      - tensorflow-estimator==2.13.0
      - tensorflow-hub==0.14.0
      - tensorflow-io-gcs-filesystem==0.32.0
      - tensorflow-metadata==1.13.1
      - termcolor==2.3.0
      - toml==0.10.2
      - tqdm==4.65.0
      - typing-extensions==4.5.0
      - urllib3==1.26.16
      - wandb==0.15.6
      - werkzeug==2.3.6
      - wrapt==1.15.0
      - zipp==3.16.2
      - zstandard==0.21.0
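
Once the tfds build command above finishes, the generated dataset can be loaded back for a quick check. A minimal sketch, assuming the dataset name language_table_use (derived from the builder class) and the same --data_dir used for the build:

import tensorflow_datasets as tfds

# data_dir must match the --data_dir passed to tfds build
ds = tfds.load('language_table_use', data_dir='<path where the final dataset was stored>', split='train')

for episode in ds.take(1):
    # 'steps' is a nested tf.data.Dataset of per-step dicts
    for step in episode['steps'].take(1):
        print(step['observation']['instruction'].shape)  # (512,) Universal Sentence Encoder embedding
        print(step['observation']['rgb'].shape)          # (360, 640, 3)
        print(step['action'].shape)                      # (2,)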