RT1复现(二)

RT1复现(二)

RT1目前只公开了模型的代码,缺少dataloader和train部分的代码
,而且也没有公开测试部分的代码和仿真环境,我们采用language_table数据集进行RT1的复现,因为该数据集开源且有仿真环境,可以对RT1模型输出的action进行一个测试

dataloader

"""读取RLDS数据集,详见:https://github.com/google-research/rlds  数据读取代码参考https://github.com/google-research/language-table"""

import dataclasses
import functools
from typing import Optional, Tuple
from clu import preprocess_spec
import jax
import rlds
import tensorflow as tf
import tensorflow_datasets as tfds
import tree

Features = preprocess_spec.Features


def create_datasets(
        rng,
        dataset_dirs,
        sequence_length,
        global_batch_size,
        target_width=320,
        target_height=180,
        random_crop_factor=None,
        cache=False,
        shuffle=True,
        shuffle_buffer_size=50_000,
        cache_dir=None,
        dataset_episode_num=10000
):
    """创建一个RLDS数据集."""

    builder = tfds.builder_from_directories(dataset_dirs)

    dataset_options = tf.data.Options()
    dataset_options.experimental_optimization.map_parallelization = True
    dataset_options.threading.private_threadpool_size = 48
    dataset_options.threading.max_intra_op_parallelism = 1

    ds = builder.as_dataset(
        split=f'train[{0}:{dataset_episode_num}]',
        decoders={"steps": {"observation": {"rgb": tfds.decode.SkipDecoding()}}},
        shuffle_files=True
    )
    
    def _pad_episode(episode, padding):
        first_item_tensor = episode["steps"].take(1).get_single_element()
        first_item_ds = tf.data.Dataset.from_tensors(first_item_tensor)

        first_item_mid_tensor = tf.nest.map_structure(
            tf.identity, first_item_tensor
        )
        first_item_mid_tensor[rlds.IS_FIRST] = False
        padding_ds = tf.data.Dataset.from_tensors(first_item_mid_tensor).repeat(
            padding
        )

        full_padding = rlds.transformations.concatenate(first_item_ds, padding_ds)
        episode["steps"] = rlds.transformations.concatenate(
            full_padding, episode["steps"].skip(1)
        )
        return episode

    ds = ds.map(
        functools.partial(_pad_episode, padding=sequence_length - 1),
        tf.data.AUTOTUNE,
    )

    def get_seqlen_pattern(step):
        return {
            rlds.OBSERVATION: tree.map_structure(
                lambda x: x[-sequence_length:], step[rlds.OBSERVATION]
            ),
            rlds.ACTION: tree.map_structure(
                lambda x: x[-sequence_length:], step[rlds.ACTION]
            ),
            rlds.IS_TERMINAL: tree.map_structure(
                lambda x: x[-sequence_length:], step[rlds.IS_TERMINAL]
            ),
        }

    ds = rlds.transformations.pattern_map_from_transform(
        episodes_dataset=ds,
        transform_fn=get_seqlen_pattern,
        respect_episode_boundaries=True,
    )

    if shuffle:
        shuffle_rng, rng = jax.random.split(rng)
        shuffle_rng = shuffle_rng[0]
        ds = ds.shuffle(shuffle_buffer_size, shuffle_rng)

    preprocessors = [
        DecodeAndRandomResizedCrop(
            random_crop_factor=random_crop_factor,
            resize_size=(target_height, target_width),
        ),
        TransformDict(),
    ]
    train_preprocess = preprocess_spec.PreprocessFn(
        preprocessors, only_jax_types=True
    )

    def _preprocess_fn(example_index, features):
        example_index = tf.cast(example_index, tf.int32)
        features[preprocess_spec.SEED_KEY] = (
            tf.random.experimental.stateless_fold_in(
                tf.cast(rng, tf.int64), example_index
            )
        )
        processed = train_preprocess(features)
        return processed

    ds = ds.enumerate().map(_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)

    ds = ds.batch(global_batch_size, drop_remainder=True)

    if cache:
        ds = ds.cache(cache_dir)

    return ds


@dataclasses.dataclass(frozen=True)
class DecodeAndRandomResizedCrop(preprocess_spec.RandomMapTransform):
    """解析图像,提取随机crop, resize并归一化"""

    random_crop_factor: Optional[float] = None
    resize_size: Tuple[int, int] = (180, 320)

    def _transform(self, features, seed):
        image = features["observation"]["rgb"]
        shape = tf.io.extract_jpeg_shape(image[0])
        raw_height, raw_width = shape[0], shape[1]
        raw_height = tf.cast(raw_height, tf.float32)
        raw_width = tf.cast(raw_width, tf.float32)

        if self.random_crop_factor is None:
            random_crop_factor = 1.0
            offset_width = 0
            offset_height = 0
            scaled_height = raw_height
            scaled_width = raw_width
        else:
            random_crop_factor = tf.constant(
                self.random_crop_factor, dtype=tf.float32
            )
            scaled_height = raw_height * random_crop_factor
            scaled_width = raw_width * random_crop_factor

            next_rng, rng = tf.unstack(tf.random.experimental.stateless_split(seed))
            offset_height = tf.random.stateless_uniform(
                shape=(),
                seed=next_rng,
                minval=0,
                maxval=tf.cast(raw_height - scaled_height, dtype=tf.int32),
                dtype=tf.int32,
            )

            next_rng, rng = tf.unstack(tf.random.experimental.stateless_split(rng))
            offset_width = tf.random.stateless_uniform(
                shape=(),
                seed=next_rng,
                minval=0,
                maxval=tf.cast(raw_width - scaled_width, dtype=tf.int32),
                dtype=tf.int32,
            )

        def apply_decode_and_crop(image):
            image = tf.image.decode_and_crop_jpeg(
                image,
                [
                    offset_height,
                    offset_width,
                    tf.cast(scaled_height, tf.int32),
                    tf.cast(scaled_width, tf.int32),
                ],
                channels=3,
            )
            return image

        image = tf.map_fn(apply_decode_and_crop, image, dtype=tf.uint8)
        image = tf.cast(image, tf.float32) / 255.0
        image = tf.image.resize(image, self.resize_size)
        features["observation"]["rgb"] = image
        return features


@dataclasses.dataclass(frozen=True)
class TransformDict(preprocess_spec.RandomMapTransform):
    """将数据存放字典格式转换成网络所需数据字典格式."""

    def _transform(self, features, seed):
        """Applies all distortions."""
        action_lable = {
            "terminate_episode": tf.one_hot(tf.cast(features["is_terminal"], dtype=tf.int32), depth=2, dtype=tf.int32),
            "action": features["action"]}
        train_observation = {"image": features["observation"]["rgb"],
                             "natural_language_embedding": features['observation']['instruction']}
        features = {"action_lable": action_lable, "train_observation": train_observation}
        return features

train

"""
robotic transformer(https://github.com/google-research/robotics_transformer)的多节点分布式训练代码,
采用tensorflow2的distribute.MultiWorkerMirroredStrategy(https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy)进行分布式训练,使用加载rlds(https://github.com/google-research/rlds)数据的方式进行数据的读取
使用方法:
    python distribute_worker_train.py --args = param, 其中args见代码中的get_args()
"""

import os
import sys
from datetime import datetime

current_dir = os.path.dirname(os.path.abspath(__file__))  # 当前文件所在目录
parent_dir = os.path.dirname(current_dir) # 上一级目录

sys.path.append(parent_dir) 
from robotics_transformer import transformer_network
from tensor2robot.utils import tensorspec_utils
from tf_agents.specs import tensor_spec
import time
from robotics_transformer import rlds_dataset_loader
# from data_loader import rlds_dataset_loader
import tensorflow as tf
import jax
import argparse
import json


def get_args():
    parser = argparse.ArgumentParser(description='获得分布式训练参数')
    parser.add_argument('--single_gpu_batch_size', '-s', help='batch size for single gpu', default=15, type=int)
    parser.add_argument('--training_epoch', '-te', help='training epoch', default=11, type=int)  # 训练epoch
    parser.add_argument('--log_step', '-ls', help='log step', default=10, type=int)
    parser.add_argument('--dataset_dirs', '-d', help='dataset path', default="/mnt/ve_share2/zy/language_table_sim_use_20000/language_table_use/1.0.0")
    parser.add_argument('--learning_rate', '-lr', help='learning rate', default=0.0005, type=float)  # 学习率  
    parser.add_argument('--vocab_size', '-vs', help='vocab size for discretization', default=256, type=int)  # 离散词典大小
    parser.add_argument('--dataset_episode_num', '-den', help='训练数据量', default=10000, type=int)
    parser.add_argument('--loaded_checkpoints_dir', '-lcd', help='模型加载目录', default="/mnt/ve_share2/zy/save_checkpoint", type=str)
    parser.add_argument('--save_model', '-sm', help='save model', default=True)
    parser.add_argument('--model_save_epoch', '-mse', help='save model at every num epoch', default=1, type=int)
    parser.add_argument('--checkpoints_saved_dir', '-csd', help='模型保存目录', default="/mnt/ve_share2/zy/robotics_transformer/save_checkpoint", type=str)
    args = parser.parse_args()
    return args


time_sequence_length = 6  # 常量,来自论文


def create_train_dataset(args, global_batch_size):
    '''创建数据集'''
    dataset_dirs = args.dataset_dirs.split("+")

    workdir = "./"
    sequence_length = time_sequence_length
    data_target_width = 456
    data_target_height = 256
    random_crop_factor = 0.95
    replay_capacity = 5_000
    seed = 42
    rng = jax.random.PRNGKey(seed)
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())

    train_ds = rlds_dataset_loader.create_datasets(
        data_rng,
        dataset_dirs=dataset_dirs,
        sequence_length=sequence_length,
        global_batch_size=global_batch_size,
        target_width=data_target_width,
        target_height=data_target_height,
        random_crop_factor=random_crop_factor,
        cache=False,
        shuffle=True,
        shuffle_buffer_size=replay_capacity,
        cache_dir=workdir,
        dataset_episode_num=args.dataset_episode_num
    )

    return train_ds


def create_model(args):
    '''创建模型'''
    data_target_width = 456
    data_target_height = 256

    state_spec = tensorspec_utils.TensorSpecStruct()

    state_spec.image = tensor_spec.BoundedTensorSpec([data_target_height, data_target_width, 3],
                                                     dtype=tf.float32,
                                                     name='image',
                                                     minimum=0.,
                                                     maximum=1.)
    state_spec.natural_language_embedding = tensor_spec.TensorSpec(
        shape=[512], dtype=tf.float32, name='natural_language_embedding')

    action_spec = tensorspec_utils.TensorSpecStruct()

    action_spec.terminate_episode = tensor_spec.BoundedTensorSpec(
        (2,), dtype=tf.int32, minimum=0, maximum=1, name='terminate_episode')

    action_spec.action = tensor_spec.BoundedTensorSpec(
        (2,), dtype=tf.float32, minimum=-0.03, maximum=0.03, name='action')

    network = transformer_network.TransformerNetwork(
        input_tensor_spec=state_spec,
        output_tensor_spec=action_spec,
        vocab_size=int(args.vocab_size),
        token_embedding_size=512,
        num_layers=8,
        layer_size=128,
        num_heads=8,
        feed_forward_size=512,
        dropout_rate=0.1,
        time_sequence_length=time_sequence_length,
        crop_size=236,
        use_token_learner=True,
        action_order=['terminate_episode', 'action'])
    return network


if __name__ == '__main__':
    os.environ.pop('TF_CONFIG', None)

    args = get_args()

    physical_devices = tf.config.experimental.list_physical_devices('GPU')

    if len(physical_devices) > 0:
        for k in range(len(physical_devices)):
            tf.config.experimental.set_memory_growth(physical_devices[k], True)
    else:
        print("GPU数量不够")
        exit("异常退出")
    
    mirrored_strategy = tf.distribute.MirroredStrategy()  
    global_batch_size = args.single_gpu_batch_size * mirrored_strategy.num_replicas_in_sync
    global_learning_rate = args.learning_rate * global_batch_size

    logdir = os.path.join(args.checkpoints_saved_dir, 'logs', datetime.now().strftime("%Y%m%d-%H%M%S"))
    summary_writer = tf.summary.create_file_writer(logdir)

    with mirrored_strategy.scope():
        network = create_model(args)
        network.create_variables()
        dataset_dirs = args.dataset_dirs.split("+")

        train_ds = create_train_dataset(args, global_batch_size)

        dist_dataset = mirrored_strategy.experimental_distribute_dataset(train_ds)
        network_state = tensor_spec.sample_spec_nest(
            network.state_spec, outer_dims=[args.single_gpu_batch_size])
        optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)

        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, model=network)
        # if tf.train.latest_checkpoint(args.loaded_checkpoints_dir):
        # print(tf.train.latest_checkpoint(args.loaded_checkpoints_dir))
        # ckpt.restore(tf.train.latest_checkpoint(args.loaded_checkpoints_dir))
        
        # print("从 %s 恢复模型" % (args.loaded_checkpoints_dir))

        current_step = ckpt.step.numpy()
        print("开始训练")
        T1 = time.time()


        @tf.function
        def train_one_step(model, observation_batch, label_batch, network_state, optimizer):
            '''单步训练'''
            with tf.GradientTape() as tape:
                model.set_actions(label_batch)
                model(observation_batch, step_type=None, network_state=network_state, training=True)
                loss = tf.reduce_mean(model.get_actor_loss())
                gradients = tape.gradient(loss, model.trainable_variables,
                                            unconnected_gradients=tf.UnconnectedGradients.ZERO)
                optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables))
                logging_info = model.get_aux_info()
                return loss, logging_info


        # action_order = network._action_order
        action_order =network._action_tokenizer.action_order
        
        for epoch in range(1, args.training_epoch):
            total_loss = 0.0
            step = 0
            T1 = time.time()

            for data in dist_dataset:
                train_observation = data["train_observation"]
                train_labels = data["action_lable"]
                per_replica_losses, logging_info = mirrored_strategy.run(
                    train_one_step, args=(network, train_observation, train_labels, network_state, optimizer))
                step = step + 1
                mean_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)   
                total_loss = total_loss + mean_loss
                ckpt.step.assign_add(1)
                # print('训练1个step耗时: %s s' % (10))
            print(f'Epoch: {epoch},Step:{step},Loss:{total_loss}')

            T2 = time.time()
            print('训练1个epoch 总耗时: ', ((T2 - T1)))
            with summary_writer.as_default():
                tf.summary.scalar('loss', total_loss, step=epoch)

            if epoch % args.model_save_epoch == 0 and args.save_model:
                checkpoint_prefix = os.path.join(args.checkpoints_saved_dir, "ckpt")
                ckpt.save(checkpoint_prefix)
                print("模型保存位置:  %s !" % (checkpoint_prefix))
        summary_writer.close()

print("正常退出!")

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

过路张

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值