Habitat Model Training Notes (1): PointNav PPO

This post summarizes how the baseline PointNav PPO model in the habitat environment is built and trained.

0 Training code

This code already appeared in the previous post; it is pasted here again:

import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torch

if __name__ == "__main__":
    run_type = "train"
    # load the baseline PPO PointNav config
    config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")

    # override dataset and scene paths for the local setup
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH = "/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
    config.TASK_CONFIG.DATASET.SCENES_DIR = "/home/yons/LK/skill_transformer-main/data/scene_datasets"
    config.freeze()

    # seed all RNGs for reproducibility
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    torch.manual_seed(config.TASK_CONFIG.SEED)
    if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
        torch.set_num_threads(1)

    # look up the trainer class registered under config.TRAINER_NAME ("ppo")
    trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)
    print('trainer_init:', trainer_init)
    assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"
    trainer = trainer_init(config)

    if run_type == "train":
        trainer.train()
    elif run_type == "eval":
        trainer.eval()

1 Trainer

TRAINER_NAME has to be set in the config file; in the code above, config.TRAINER_NAME is ppo. Then, through

trainer = trainer_init(config)

trainer becomes an rl.ppo.ppo_trainer.PPOTrainer object. That class is defined in ......./habitat_baselines/rl/ppo/ppo_trainer.py.
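For reference, TRAINER_NAME comes from the YAML config; in ppo_pointnav_example.yaml it is set by a line roughly like the following (a sketch of that single line only, not the full config):

  TRAINER_NAME: "ppo"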

2 Training process

The training process calls the train() method of rl.ppo.ppo_trainer.PPOTrainer:

  trainer.train()

TODO: analyze the train() method

3 Model structure definition

PPO uses an actor-critic structure, which needs two networks: an actor network and a critic network. The two networks may or may not share parameters. The PPO in habitat shares parameters in the feature-extraction stage and then splits into two heads.
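Before diving into the habitat classes, here is a minimal, self-contained sketch (not habitat code; all names are made up) of what such a shared-trunk actor-critic looks like:

import torch
import torch.nn as nn

class SharedActorCritic(nn.Module):
    """Toy actor-critic: one shared trunk, two output heads."""

    def __init__(self, obs_dim: int, hidden_size: int, num_actions: int):
        super().__init__()
        # shared feature extractor (plays the role of "net" in habitat)
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden_size), nn.ReLU())
        self.actor = nn.Linear(hidden_size, num_actions)  # action logits
        self.critic = nn.Linear(hidden_size, 1)           # state value

    def forward(self, obs: torch.Tensor):
        features = self.trunk(obs)  # shared part
        return self.actor(features), self.critic(features)

The habitat implementation follows the same pattern but splits these pieces across several classes; the trainer that owns them is PPOTrainer: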

@baseline_registry.register_trainer(name="ddppo")
@baseline_registry.register_trainer(name="ppo")
class PPOTrainer(BaseRLTrainer):
    r"""Trainer class for PPO algorithm
    Paper: https://arxiv.org/abs/1707.06347.
    """
    supported_tasks = ["Nav-v0"]

    SHORT_ROLLOUT_THRESHOLD: float = 0.25
    _is_distributed: bool
    envs: VectorEnv
    agent: PPO
    actor_critic: NetPolicy

    def __init__(self, config=None):
        super().__init__(config)
        self.actor_critic = None
        self.agent = None
        self.envs = None
        self.obs_transforms = []

        self._static_encoder = False
        self._encoder = None
        self._obs_space = None

        # Distributed if the world size would be
        # greater than 1
        self._is_distributed = get_distrib_size()[2] > 1

The code above shows that self.actor_critic has type NetPolicy; the NetPolicy class is defined in ......./habitat_baselines/rl/ppo/policy.py. In the _setup_actor_critic_agent() method of ......./habitat_baselines/rl/ppo/ppo_trainer.py,

        policy = baseline_registry.get_policy(self.config.RL.POLICY.name)
......
        self.actor_critic = policy.from_config(
            self.config,
            observation_space,
            self.policy_action_space,
            orig_action_space=self.orig_policy_action_space,
        )

self.actor_critic is defined, as shown above.
Which policy class gets used is determined by self.config.RL.POLICY.name (a class name). For example:

  name: PointNavResNetPolicy

This policy is defined in ......./habitat_baselines/rl/ddppo/policy/resnet_policy.py, which also defines its network structure.
A second example:

  name: PointNavBaselinePolicy

This policy is defined directly in ......./habitat_baselines/rl/ppo/policy.py.
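As a rough illustration of the registry mechanism (this assumes the policy module has already been imported, so the @baseline_registry.register_policy decorator shown further below has run):

from habitat_baselines.common.baseline_registry import baseline_registry

# config.RL.POLICY.name is looked up in the registry;
# get_policy returns the policy class itself, not an instance
policy_cls = baseline_registry.get_policy("PointNavBaselinePolicy")
# policy_cls.from_config(...) then builds the actual actor_critic object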

The NetPolicy class in ......./habitat_baselines/rl/ppo/policy.py is the key class; it defines most of the important machinery. The other policies (for example PointNavResNetPolicy, defined in another file, and the PointNavBaselinePolicy mentioned above) all inherit from it.
The first parameter of NetPolicy's initializer is net, and this net is supplied by the subclass:

    def __init__(
        self, net, action_space, policy_config=None, aux_loss_config=None
    ):

For example:

@baseline_registry.register_policy
class PointNavBaselinePolicy(NetPolicy):
    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space,
        hidden_size: int = 512,
        aux_loss_config=None,
        **kwargs,
    ):
        super().__init__(
            PointNavBaselineNet(  # type: ignore
                observation_space=observation_space,
                hidden_size=hidden_size,
                **kwargs,
            ),
            action_space=action_space,
            aux_loss_config=aux_loss_config,
        )

The first argument of super().__init__(),

PointNavBaselineNet(  # type: ignore
                observation_space=observation_space,
                hidden_size=hidden_size,
                **kwargs,
            )

is thus initialized through the parent class (NetPolicy) __init__(). The PointNavBaselineNet used in this example is defined in the same file (other networks are defined in other files):

class PointNavBaselineNet(Net):
    r"""Network which passes the input image through CNN and concatenates
    goal vector with CNN's output and passes that through RNN.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        hidden_size: int,
    ):
        super().__init__()

        if (
            IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            in observation_space.spaces
        ):
            self._n_input_goal = observation_space.spaces[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            ].shape[0]
        elif PointGoalSensor.cls_uuid in observation_space.spaces:
            self._n_input_goal = observation_space.spaces[
                PointGoalSensor.cls_uuid
            ].shape[0]
        elif ImageGoalSensor.cls_uuid in observation_space.spaces:
            goal_observation_space = spaces.Dict(
                {"rgb": observation_space.spaces[ImageGoalSensor.cls_uuid]}
            )
            self.goal_visual_encoder = SimpleCNN(
                goal_observation_space, hidden_size
            )
            self._n_input_goal = hidden_size

        self._hidden_size = hidden_size

        self.visual_encoder = SimpleCNN(observation_space, hidden_size)

        self.state_encoder = build_rnn_state_encoder(
            (0 if self.is_blind else self._hidden_size) + self._n_input_goal,
            self._hidden_size,
        )

        self.train()

    @property
    def output_size(self):
        return self._hidden_size

    @property
    def is_blind(self):
        return self.visual_encoder.is_blind

    @property
    def num_recurrent_layers(self):
        return self.state_encoder.num_recurrent_layers

    @property
    def perception_embedding_size(self):
        return self._hidden_size

    def forward(
        self,
        observations,
        rnn_hidden_states,
        prev_actions,
        masks,
        rnn_build_seq_info: Optional[Dict[str, torch.Tensor]] = None,
    ):
        aux_loss_state = {}
        if IntegratedPointGoalGPSAndCompassSensor.cls_uuid in observations:
            target_encoding = observations[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            ]
        elif PointGoalSensor.cls_uuid in observations:
            target_encoding = observations[PointGoalSensor.cls_uuid]
        elif ImageGoalSensor.cls_uuid in observations:
            image_goal = observations[ImageGoalSensor.cls_uuid]
            target_encoding = self.goal_visual_encoder({"rgb": image_goal})

        x = [target_encoding]

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            x = [perception_embed] + x
            aux_loss_state["perception_embed"] = perception_embed

        x_out = torch.cat(x, dim=1)
        x_out, rnn_hidden_states = self.state_encoder(
            x_out, rnn_hidden_states, masks, rnn_build_seq_info
        )
        aux_loss_state["rnn_output"] = x_out

        return x_out, rnn_hidden_states, aux_loss_state

The most important method in the NetPolicy class is act(); it defines the overall structure of the network (shared trunk, then the action and value heads).

    def act(
        self,
        observations,
        rnn_hidden_states,
        prev_actions,
        masks,
        deterministic=False,
        envs_to_pause=None,
        rtgs=None,
    ):
        if envs_to_pause is not None:
            features, rnn_hidden_states, _ = self.net(
                observations, rnn_hidden_states, prev_actions, masks, envs_to_pause=envs_to_pause, rtgs=rtgs,
            )
        else:
            features, rnn_hidden_states, _ = self.net(
                observations, rnn_hidden_states, prev_actions, masks, 
            )
        distribution = self.action_distribution(features)
        value = self.critic(features)
        # self.cur_skill = torch.argmax(self.aux_head(features), dim=-1)

        if deterministic:
            if self.action_distribution_type == "categorical":
                action = distribution.mode()
            elif self.action_distribution_type == "gaussian":
                action = distribution.mean
            elif self.action_distribution_type == "mixed":
                action = distribution.mean()
        else:
            action = distribution.sample()

        action_log_probs = distribution.log_probs(action)

        return value, action, action_log_probs, rnn_hidden_states

Analysis:
self.net() is the part whose parameters are shared between the actor and the critic. The following two lines then use the two separate branches:

        distribution = self.action_distribution(features)
        value = self.critic(features)

Here self.action_distribution is the action branch and self.critic is the branch that computes the value. Both are defined in the NetPolicy initializer:

    def __init__(
        self, net, action_space, policy_config=None, aux_loss_config=None
    ):
        super().__init__()
        self.net = net
        self.dim_actions = get_num_actions(action_space)
        self.action_distribution: Union[CategoricalNet, GaussianNet]

        if policy_config is None:
            self.action_distribution_type = "categorical"
        else:
            self.action_distribution_type = (
                policy_config.action_distribution_type
            )

        if self.action_distribution_type == "categorical":
            self.action_distribution = CategoricalNet(
                self.net.output_size, self.dim_actions
            )
        elif self.action_distribution_type == "gaussian":
            self.action_distribution = GaussianNet(
                self.net.output_size,
                self.dim_actions,
                policy_config.ACTION_DIST,
            )
        elif self.action_distribution_type == "mixed":
            self.action_distribution = MixedDistributionNet(
                self.net.output_size,
                policy_config.ACTION_DIST,
                action_space,
            )
        else:
            raise ValueError(
                f"Action distribution {self.action_distribution_type}"
                "not supported."
            )

        self.critic = CriticHead(self.net.output_size)

        self.aux_loss_modules = nn.ModuleDict()
        for aux_loss_name in (
            () if aux_loss_config is None else aux_loss_config.enabled
        ):
            aux_loss = baseline_registry.get_auxiliary_loss(aux_loss_name)

            self.aux_loss_modules[aux_loss_name] = aux_loss(
                action_space,
                self.net,
                **getattr(aux_loss_config, aux_loss_name),
            )

The initializer thus mainly defines self.action_distribution, self.critic and self.aux_loss_modules.

  • The definition of self.action_distribution depends on whether the action space is categorical (discrete), Gaussian (continuous) or mixed; the three cases use CategoricalNet, GaussianNet and MixedDistributionNet respectively, which are defined in habitat_baselines.utils.common:
from habitat_baselines.utils.common import (
    CategoricalNet,
    GaussianNet,
    get_num_actions,
)
  • self.critic uses CriticHead(), which is defined in the same file, policy.py:
class CriticHead(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.fc = nn.Linear(input_size, 1)
        # self.fc2 = nn.Sequential(
        #     nn.Linear(input_size, 2*input_size),
        #     nn.ReLU(),
        #     nn.Linear(2*input_size, input_size),
        #     nn.ReLU(),
        #     nn.Linear(input_size, 1)
        # )
        nn.init.orthogonal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)

    def forward(self, x):
        return self.fc(x)
  • self.aux_loss_modules is populated from aux_loss_config, which is passed in via from_config. For example:
    @classmethod
    def from_config(
        cls,
        config: Config,
        observation_space: spaces.Dict,
        action_space,
        **kwargs,
    ):
        return cls(
            observation_space=observation_space,
            action_space=action_space,
            hidden_size=config.RL.PPO.hidden_size,
            aux_loss_config=config.RL.auxiliary_losses,
        )

This is the from_config method of the PointNavBaselinePolicy class; aux_loss_config=config.RL.auxiliary_losses comes from the config file:

  auxiliary_losses:
    cpca:
      future_subsample: 2
      k: 20
      loss_scale: 0.1
      time_subsample: 6
    enabled: []
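Nothing is enabled in this example; judging from the loop over aux_loss_config.enabled in NetPolicy.__init__ above, enabling an auxiliary loss presumably means listing its name there, e.g.:

    enabled: ["cpca"]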

In Python, cls refers to the class itself and is the first parameter of a classmethod; self is the first parameter of an instance method.
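A toy example (unrelated to habitat) of the difference:

class Greeter:
    def __init__(self, name):
        self.name = name           # self: this particular instance

    @classmethod
    def from_upper(cls, name):     # cls: the class itself (Greeter or a subclass)
        return cls(name.upper())   # calls __init__, returns a new instance

g = Greeter.from_upper("bob")      # g.name == "BOB"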

4 Loss definition

The following code defines self.agent; it lives in the _setup_actor_critic_agent() method of ......./habitat_baselines/rl/ppo/ppo_trainer.py:

self.agent = (DDPPO if self._is_distributed else PPO).from_config(
    self.actor_critic, ppo_cfg
)

Which class is used depends on whether distributed PPO is needed; in the non-distributed case this reduces to

self.agent = PPO.from_config(self.actor_critic, ppo_cfg)

from_config is a classmethod that builds and returns a PPO instance configured from ppo_cfg.

In the train() method of ......./habitat_baselines/rl/ppo/ppo_trainer.py, the losses are obtained via

losses = self._update_agent()

and inside the _update_agent() method the agent is put into training mode:

self.agent.train()
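The PPO losses themselves are computed in the agent's update step. As a hedged sketch of the core computation (the clipped surrogate objective plus a value regression term; the real habitat PPO class additionally handles minibatching, the entropy bonus, gradient clipping, etc.), it looks roughly like:

import torch

def ppo_losses(new_log_probs, old_log_probs, advantages, values, returns,
               clip_param=0.2):
    # probability ratio between the current policy and the data-collecting policy
    ratio = torch.exp(new_log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    action_loss = -torch.min(surr1, surr2).mean()         # clipped policy loss
    value_loss = 0.5 * (returns - values).pow(2).mean()   # value function loss
    return action_loss, value_loss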

5 Data collection and how the data is fed into the network
