Dexcap复现代码模型训练全流程(二)——train.py增补:模型适配数据格式

此篇和上一篇(一)紧密相关!

在 robomimic 框架中,模型如何知道数据的格式并选择合适的数据进行训练,是通过以下几个步骤完成的,具体包括:

数据格式的解析 -> 元数据提取 -> 模型和训练过程的自动化配置

  • 配置文件:数据格式由配置文件定义,包括数据路径、类型(hdf5)和格式(state 等)->

  • 元数据提取:从数据集中提取特征和动作的形状,用于构建模型和数据加载器->

  • 动态加载与适配:数据加载器会根据数据格式加载和处理数据,模型会根据数据的输入维度动态调整网络结构->

  • 归一化处理:数据归一化确保输入特征和动作的数值范围适合训练

通过这些步骤,robomimic 可以自动适配多种数据格式

目录

1 数据格式的配置与解析

1.1 配置文件(config)

1.2 数据格式的动态加载

2 元数据的提取

3 数据加载器的动态生成

4 模型与数据格式的自动适配

5 数据预处理与归一化


1 数据格式的配置与解析

1.1 配置文件(config)

配置文件是整个训练过程的核心,定义了数据的路径和格式,以及模型如何处理这些数据

train.py 中,配置文件通过以下方式被加载并用于解析数据:

    if args.config is not None:
        # 从配置文件加载实验配置
        ext_cfg = json.load(open(args.config, 'r'))
        config = config_factory(ext_cfg["algo_name"])
        with config.values_unlocked():
            config.update(ext_cfg)
    else:
        # 动态创建默认配置
        config = config_factory(args.algo)

config.train 中,定义了训练数据的路径和格式等参数

    "train": {
        "data": [
            {
                "path": "[DATASET_ROOT_PATH]/hand_wiping_1-14_5actiongap_10000points.hdf5"
            }
        ],
        "output_dir": "trained_models",
        "num_data_workers": 2,
        "hdf5_cache_mode": null,
        "hdf5_use_swmr": true,
        "hdf5_load_next_obs": false,
        "hdf5_normalize_obs": false,
        "hdf5_filter_key": null,
        "hdf5_validation_filter_key": null,
        "seq_length": 20,
        "pad_seq_length": true,
        "frame_stack": 1,
        "pad_frame_stack": true,
        "dataset_keys": [
            "actions",
            "rewards",
            "dones"
        ],
        "action_keys": [
            "actions"
        ],
        "action_config": {
            "actions": {
                "normalization": "min_max"
            },
            "action_dict/abs_pos": {
                "normalization": "min_max"
            },
            "action_dict/abs_rot_axis_angle": {
                "normalization": "min_max",
                "format": "rot_axis_angle"
            },
            "action_dict/abs_rot_6d": {
                "normalization": null,
                "format": "rot_6d"
            },
            "action_dict/rel_pos": {
                "normalization": null
            },
            "action_dict/rel_rot_axis_angle": {
                "normalization": null,
                "format": "rot_axis_angle"
            },
            "action_dict/rel_rot_6d": {
                "normalization": null,
                "format": "rot_6d"
            },
            "action_dict/gripper": {
                "normalization": null
            }
        },
        "goal_mode": null,
        "cuda": true,
        "batch_size": 16,
        "num_epochs": 3000,
        "seed": 1,
        "data_format": "robomimic"
    }

data_format: 数据格式为 robomimic

1.2 数据格式的动态加载

配置文件会告知程序数据的存储路径和格式,并利用 robomimic 提供的工具自动解析

    # 加载训练数据和验证数据
    trainset, validset = TrainUtils.load_data_for_training(
        config, obs_keys=shape_meta["all_obs_keys"])
    train_sampler = trainset.get_dataset_sampler()

TrainUtils.load_data_for_training 方法的内部逻辑会根据配置文件的 data_format 和 data.path 自动加载相应的数据集

数据集的类型为 hdf5 文件,是一种通用的存储格式,包含训练数据、元数据和动作

2 元数据的提取

在训练之前,模型会提取数据集的元信息,包括输入的特征形状(observation shapes)和动作空间的维度,这些元信息用于初始化模型和数据加载器

    # 获取数据的形状元数据(如观测维度)
    shape_meta = FileUtils.get_shape_metadata_from_dataset(
        dataset_path=dataset_path,
        action_keys=config.train.action_keys,
        all_obs_keys=config.all_obs_keys,
        ds_format=ds_format,
        verbose=True
    )

FileUtils.get_shape_metadata_from_dataset:从数据集中提取特征和动作的形状(如图像分辨率、状态维度等)

根据 action_keys 和 obs_keys 确定哪些部分作为输入

action_keys:

        "action_keys": [
            "actions"

obs_keys:

    "observation": {
        "modalities": {
            "obs": {
                "low_dim": [
                    "robot0_eef_pos",
                    "robot0_eef_quat",
                    "robot0_eef_hand",
                    "pointcloud"
                ]

3 数据加载器的动态生成

在 load_data_for_training 中,robomimic 根据数据集的格式生成合适的数据加载器:

    # 初始化数据加载器(训练和验证)
    train_loader = DataLoader(
        dataset=trainset,
        sampler=train_sampler,
        batch_size=config.train.batch_size,
        shuffle=(train_sampler is None),
        num_workers=config.train.num_data_workers,
        drop_last=True
    )

支持的格式:

  • 如果是序列数据(sequence),数据加载器会按照时间序列切分

  • 如果是图像数据(image),会自动加载图像并进行预处理

  • 如果是状态数据(state),直接使用状态作为模型输入

数据集抽象:

  • 数据集通常实现 __getitem__ 和 __len__ 方法,以兼容 PyTorch 的 DataLoader

4 模型与数据格式的自动适配

在模型初始化时,robomimic 的算法工厂会根据 shape_meta 动态构建模型:

    # 使用算法工厂初始化模型
    model = algo_factory(
        algo_name=config.algo_name,
        config=config,
        obs_key_shapes=shape_meta["all_shapes"],
        ac_dim=shape_meta["ac_dim"],
        device=device,
    )

模型输入:

  • obs_key_shapes:告知模型每种输入特征的形状

  • ac_dim:动作空间的维度

动态适配:

  • 模型根据输入形状自动调整其前向网络,例如卷积层或全连接层的大小

5 数据预处理与归一化

训练前会对数据进行预处理,包括归一化操作:

    # maybe retreve statistics for normalizing observations
    obs_normalization_stats = None
    if config.train.hdf5_normalize_obs:
        obs_normalization_stats = trainset.get_obs_normalization_stats()

    # maybe retreve statistics for normalizing actions
    action_normalization_stats = trainset.get_action_normalization_stats()

观测值(observation)和动作(action)会根据数据统计进行归一化,提升训练的稳定性

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值