This article summarizes how the baseline PointNav PPO model in the Habitat environment is constructed and trained.
0 Training code
This code appeared in the previous article; it is pasted here again for reference:
import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torch

if __name__ == "__main__":
    run_type = "train"
    config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH = "/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
    config.TASK_CONFIG.DATASET.SCENES_DIR = "/home/yons/LK/skill_transformer-main/data/scene_datasets"
    config.freeze()

    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    torch.manual_seed(config.TASK_CONFIG.SEED)

    if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
        torch.set_num_threads(1)

    trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)
    print('trainer_init:', trainer_init)
    assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"
    trainer = trainer_init(config)

    if run_type == "train":
        trainer.train()
    elif run_type == "eval":
        trainer.eval()
1 Trainer
TRAINER_NAME must be set in the config file; in the code above, config.TRAINER_NAME is ppo. After the line

trainer = trainer_init(config)

trainer becomes an instance of rl.ppo.ppo_trainer.PPOTrainer, which is defined in ......./habitat_baselines/rl/ppo/ppo_trainer.py.
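For intuition, here is a minimal sketch of how a name-to-class registry of this kind works. The Registry class below is a simplified stand-in, not Habitat's actual baseline_registry:

# Minimal sketch of a name -> class registry (simplified; not Habitat's actual code).
class Registry:
    def __init__(self):
        self._trainers = {}

    def register_trainer(self, name):
        # Used as a decorator: @registry.register_trainer(name="ppo")
        def wrapper(cls):
            self._trainers[name] = cls
            return cls
        return wrapper

    def get_trainer(self, name):
        # Returns the class registered under `name`, or None if unknown.
        return self._trainers.get(name)

registry = Registry()

@registry.register_trainer(name="ppo")
class DummyTrainer:
    def __init__(self, config):
        self.config = config

trainer_cls = registry.get_trainer("ppo")   # -> DummyTrainer
trainer = trainer_cls(config=None)          # mirrors trainer_init(config) above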
2 Training process
We can see that training is driven by the train() method of rl.ppo.ppo_trainer.PPOTrainer:

trainer.train()

TODO: analyze the train() method.
3 Model structure definition
PPO has an actor-critic structure, which needs two networks: an actor network and a critic network. The two networks may or may not share parameters. Habitat's PPO shares parameters in the feature-extraction stage and then branches into two separate heads.
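A minimal sketch of this shared-trunk idea, assuming a generic feature extractor (all names below are illustrative, not Habitat's):

import torch
import torch.nn as nn

class SharedActorCritic(nn.Module):
    """Toy actor-critic with a shared feature trunk and two heads."""

    def __init__(self, obs_dim: int, hidden_size: int, num_actions: int):
        super().__init__()
        # Shared trunk: both heads consume its output features.
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden_size), nn.ReLU())
        self.actor_head = nn.Linear(hidden_size, num_actions)  # action logits
        self.critic_head = nn.Linear(hidden_size, 1)           # state value

    def forward(self, obs):
        features = self.trunk(obs)
        return self.actor_head(features), self.critic_head(features)

model = SharedActorCritic(obs_dim=8, hidden_size=32, num_actions=4)
logits, value = model(torch.randn(2, 8))  # logits: (2, 4), value: (2, 1)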
@baseline_registry.register_trainer(name="ddppo")
@baseline_registry.register_trainer(name="ppo")
class PPOTrainer(BaseRLTrainer):
r"""Trainer class for PPO algorithm
Paper: https://arxiv.org/abs/1707.06347.
"""
supported_tasks = ["Nav-v0"]
SHORT_ROLLOUT_THRESHOLD: float = 0.25
_is_distributed: bool
envs: VectorEnv
agent: PPO
actor_critic: NetPolicy
def __init__(self, config=None):
super().__init__(config)
self.actor_critic = None
self.agent = None
self.envs = None
self.obs_transforms = []
self._static_encoder = False
self._encoder = None
self._obs_space = None
# Distributed if the world size would be
# greater than 1
self._is_distributed = get_distrib_size()[2] > 1
From the code above we can see that self.actor_critic has type NetPolicy, and the NetPolicy class is defined in ......./habitat_baselines/rl/ppo/policy.py. In the _setup_actor_critic_agent() method of ......./habitat_baselines/rl/ppo/ppo_trainer.py,
policy = baseline_registry.get_policy(self.config.RL.POLICY.name)
......
self.actor_critic = policy.from_config(
    self.config,
    observation_space,
    self.policy_action_space,
    orig_action_space=self.orig_policy_action_space,
)
defines self.actor_critic, as shown above.

1. Which policy is used is determined by self.config.RL.POLICY.name (a class name). For example:

name: PointNavResNetPolicy

This policy is defined in ......./habitat_baselines/rl/ddppo/policy/resnet_policy.py, which also defines its network structure.

A second example:

name: PointNavBaselinePolicy

This policy is defined directly in ......./habitat_baselines/rl/ppo/policy.py.

The NetPolicy class in ......./habitat_baselines/rl/ppo/policy.py is the key piece: it defines most of the important machinery. The other policies, such as PointNavResNetPolicy (defined in another file) and the PointNavBaselinePolicy mentioned above, all inherit from it; a sketch of plugging in a custom policy follows below.
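As a hedged sketch of this plug-in mechanism: a new policy would be registered with the same decorator and then selected purely by its class name in the config. MyCustomPolicy below is a hypothetical name, not part of Habitat:

from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.rl.ppo.policy import NetPolicy

@baseline_registry.register_policy
class MyCustomPolicy(NetPolicy):  # hypothetical example policy
    ...  # would supply its own net to NetPolicy.__init__, as shown below

# Selected in the YAML config by class name:
#   RL:
#     POLICY:
#       name: MyCustomPolicy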
The first parameter of NetPolicy's initializer is net, and this net is supplied by the subclasses that inherit from it:
def __init__(
    self, net, action_space, policy_config=None, aux_loss_config=None
):
For example:
@baseline_registry.register_policy
class PointNavBaselinePolicy(NetPolicy):
    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space,
        hidden_size: int = 512,
        aux_loss_config=None,
        **kwargs,
    ):
        super().__init__(
            PointNavBaselineNet(  # type: ignore
                observation_space=observation_space,
                hidden_size=hidden_size,
                **kwargs,
            ),
            action_space=action_space,
            aux_loss_config=aux_loss_config,
        )
The first argument to super().__init__() here,
PointNavBaselineNet(  # type: ignore
    observation_space=observation_space,
    hidden_size=hidden_size,
    **kwargs,
)
is initialized through the parent class (NetPolicy) __init__() function. The PointNavBaselineNet used in this example is defined in the same file (other networks are defined in other files):
class PointNavBaselineNet(Net):
    r"""Network which passes the input image through CNN and concatenates
    goal vector with CNN's output and passes that through RNN.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        hidden_size: int,
    ):
        super().__init__()

        if (
            IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            in observation_space.spaces
        ):
            self._n_input_goal = observation_space.spaces[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            ].shape[0]
        elif PointGoalSensor.cls_uuid in observation_space.spaces:
            self._n_input_goal = observation_space.spaces[
                PointGoalSensor.cls_uuid
            ].shape[0]
        elif ImageGoalSensor.cls_uuid in observation_space.spaces:
            goal_observation_space = spaces.Dict(
                {"rgb": observation_space.spaces[ImageGoalSensor.cls_uuid]}
            )
            self.goal_visual_encoder = SimpleCNN(
                goal_observation_space, hidden_size
            )
            self._n_input_goal = hidden_size

        self._hidden_size = hidden_size

        self.visual_encoder = SimpleCNN(observation_space, hidden_size)

        self.state_encoder = build_rnn_state_encoder(
            (0 if self.is_blind else self._hidden_size) + self._n_input_goal,
            self._hidden_size,
        )

        self.train()

    @property
    def output_size(self):
        return self._hidden_size

    @property
    def is_blind(self):
        return self.visual_encoder.is_blind

    @property
    def num_recurrent_layers(self):
        return self.state_encoder.num_recurrent_layers

    @property
    def perception_embedding_size(self):
        return self._hidden_size

    def forward(
        self,
        observations,
        rnn_hidden_states,
        prev_actions,
        masks,
        rnn_build_seq_info: Optional[Dict[str, torch.Tensor]] = None,
    ):
        aux_loss_state = {}
        if IntegratedPointGoalGPSAndCompassSensor.cls_uuid in observations:
            target_encoding = observations[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid
            ]
        elif PointGoalSensor.cls_uuid in observations:
            target_encoding = observations[PointGoalSensor.cls_uuid]
        elif ImageGoalSensor.cls_uuid in observations:
            image_goal = observations[ImageGoalSensor.cls_uuid]
            target_encoding = self.goal_visual_encoder({"rgb": image_goal})

        x = [target_encoding]

        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            x = [perception_embed] + x
            aux_loss_state["perception_embed"] = perception_embed

        x_out = torch.cat(x, dim=1)
        x_out, rnn_hidden_states = self.state_encoder(
            x_out, rnn_hidden_states, masks, rnn_build_seq_info
        )
        aux_loss_state["rnn_output"] = x_out

        return x_out, rnn_hidden_states, aux_loss_state
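To make the data flow concrete: with hidden_size=512 and the integrated pointgoal GPS+compass sensor (a 2-dimensional goal vector), the RNN input size is (0 if blind else 512) + 2, i.e. 514 with vision and 2 without. A commented walk-through of one forward pass, with all shapes assumed for illustration:

# Assumed shapes for one forward pass over N parallel environments:
#   observations["pointgoal_with_gps_compass"]            -> (N, 2)   target_encoding
#   perception_embed = self.visual_encoder(observations)  -> (N, 512)
#   x_out = torch.cat([perception_embed, target_encoding]) -> (N, 514)
#   features, hidden = self.state_encoder(x_out, ...)     -> (N, 512), updated RNN state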
The most important method in the NetPolicy class is act, which wires the whole network together:
def act(
    self,
    observations,
    rnn_hidden_states,
    prev_actions,
    masks,
    deterministic=False,
    envs_to_pause=None,
    rtgs=None,
):
    if envs_to_pause is not None:
        features, rnn_hidden_states, _ = self.net(
            observations, rnn_hidden_states, prev_actions, masks,
            envs_to_pause=envs_to_pause, rtgs=rtgs,
        )
    else:
        features, rnn_hidden_states, _ = self.net(
            observations, rnn_hidden_states, prev_actions, masks,
        )
    distribution = self.action_distribution(features)
    value = self.critic(features)
    # self.cur_skill = torch.argmax(self.aux_head(features), dim=-1)

    if deterministic:
        if self.action_distribution_type == "categorical":
            action = distribution.mode()
        elif self.action_distribution_type == "gaussian":
            action = distribution.mean
        elif self.action_distribution_type == "mixed":
            action = distribution.mean()
    else:
        action = distribution.sample()

    action_log_probs = distribution.log_probs(action)

    return value, action, action_log_probs, rnn_hidden_states
Analysis:
self.net() is the part whose parameters are shared by the actor and the critic. The following two lines then use separate branches:

distribution = self.action_distribution(features)
value = self.critic(features)

Here self.action_distribution is the action branch and self.critic is the value branch; both are defined in NetPolicy's initializer:
def __init__(
    self, net, action_space, policy_config=None, aux_loss_config=None
):
    super().__init__()
    self.net = net
    self.dim_actions = get_num_actions(action_space)
    self.action_distribution: Union[CategoricalNet, GaussianNet]

    if policy_config is None:
        self.action_distribution_type = "categorical"
    else:
        self.action_distribution_type = (
            policy_config.action_distribution_type
        )

    if self.action_distribution_type == "categorical":
        self.action_distribution = CategoricalNet(
            self.net.output_size, self.dim_actions
        )
    elif self.action_distribution_type == "gaussian":
        self.action_distribution = GaussianNet(
            self.net.output_size,
            self.dim_actions,
            policy_config.ACTION_DIST,
        )
    elif self.action_distribution_type == "mixed":
        self.action_distribution = MixedDistributionNet(
            self.net.output_size,
            policy_config.ACTION_DIST,
            action_space,
        )
    else:
        raise ValueError(
            f"Action distribution {self.action_distribution_type} "
            "not supported."
        )

    self.critic = CriticHead(self.net.output_size)

    self.aux_loss_modules = nn.ModuleDict()
    for aux_loss_name in (
        () if aux_loss_config is None else aux_loss_config.enabled
    ):
        aux_loss = baseline_registry.get_auxiliary_loss(aux_loss_name)

        self.aux_loss_modules[aux_loss_name] = aux_loss(
            action_space,
            self.net,
            **getattr(aux_loss_config, aux_loss_name),
        )
As we can see, the initializer mainly defines self.action_distribution, self.critic, and self.aux_loss_modules.

The definition of self.action_distribution depends on whether the actions are discrete, continuous, or mixed; the three cases use CategoricalNet, GaussianNet, and MixedDistributionNet respectively, and these networks are defined in habitat_baselines.utils.common:
from habitat_baselines.utils.common import (
    CategoricalNet,
    GaussianNet,
    get_num_actions,
)
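For intuition, a categorical action head is typically just a linear layer that maps features to logits and wraps them in a categorical distribution. The sketch below is a simplified stand-in under that assumption, not Habitat's actual CategoricalNet:

import torch
import torch.nn as nn

class TinyCategoricalHead(nn.Module):
    """Simplified stand-in for a categorical action head (illustrative only)."""

    def __init__(self, num_inputs: int, num_actions: int):
        super().__init__()
        self.linear = nn.Linear(num_inputs, num_actions)

    def forward(self, features):
        logits = self.linear(features)
        return torch.distributions.Categorical(logits=logits)

head = TinyCategoricalHead(num_inputs=512, num_actions=4)
dist = head(torch.randn(2, 512))
action = dist.sample()            # shape (2,): one discrete action per env
log_prob = dist.log_prob(action)  # shape (2,): log-probability of each action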
The definition of self.critic uses CriticHead(), a network defined in the same file, policy.py:
class CriticHead(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.fc = nn.Linear(input_size, 1)
        # self.fc2 = nn.Sequential(
        #     nn.Linear(input_size, 2 * input_size),
        #     nn.ReLU(),
        #     nn.Linear(2 * input_size, input_size),
        #     nn.ReLU(),
        #     nn.Linear(input_size, 1),
        # )
        nn.init.orthogonal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)

    def forward(self, x):
        return self.fc(x)
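The head maps each feature vector to a single scalar state value; orthogonal weight initialization with a zero bias is a common PPO implementation detail for value heads. A quick shape check, assuming the CriticHead class above and a hidden size of 512:

import torch

head = CriticHead(input_size=512)   # value head over the shared features
features = torch.randn(4, 512)      # shared-net features for a batch of 4 envs
value = head(features)              # -> shape (4, 1): one scalar value per env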
Finally, self.aux_loss_modules holds the enabled auxiliary losses. Where its configuration comes from can be seen in the from_config method below:
@classmethod
def from_config(
    cls,
    config: Config,
    observation_space: spaces.Dict,
    action_space,
    **kwargs,
):
    return cls(
        observation_space=observation_space,
        action_space=action_space,
        hidden_size=config.RL.PPO.hidden_size,
        aux_loss_config=config.RL.auxiliary_losses,
    )
This is the from_config method of the PointNavBaselinePolicy class. Its aux_loss_config=config.RL.auxiliary_losses argument comes from the config file:
auxiliary_losses:
  cpca:
    future_subsample: 2
    k: 20
    loss_scale: 0.1
    time_subsample: 6
  enabled: []
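With enabled: [] as above, the loop in NetPolicy.__init__ iterates over an empty list, so no auxiliary-loss module is constructed; to actually use CPCA, "cpca" would be listed under enabled. The gating logic reduces to this sketch (names simplified, not Habitat's config objects):

# Minimal sketch of how `enabled` gates auxiliary-loss construction.
aux_loss_config = {"enabled": [], "cpca": {"k": 20, "loss_scale": 0.1}}

aux_loss_modules = {}
for name in aux_loss_config["enabled"]:   # empty -> loop body never runs
    aux_loss_modules[name] = aux_loss_config[name]

print(aux_loss_modules)  # {} -> no auxiliary losses are active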
In Python, cls refers to the class itself and is the first parameter of a class method; self refers to the instance and is the first parameter of an instance method.
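A minimal illustration of the difference (toy class, not from Habitat):

class Policy:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size  # `self` is this particular instance

    @classmethod
    def from_config(cls, config):
        # `cls` is the class itself, so subclasses get their own type back.
        return cls(hidden_size=config["hidden_size"])

p = Policy.from_config({"hidden_size": 512})  # calls Policy(hidden_size=512)
print(type(p).__name__, p.hidden_size)        # Policy 512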
4 Loss definition
The code below defines self.agent; it appears in the _setup_actor_critic_agent method of ......./habitat_baselines/rl/ppo/ppo_trainer.py:
self.agent = (DDPPO if self._is_distributed else PPO).from_config(
    self.actor_critic, ppo_cfg
)
The choice simply depends on whether distributed PPO is needed. If training is not distributed, this reduces to

self.agent = PPO.from_config(self.actor_critic, ppo_cfg)

from_config is a classmethod, so it returns an instance of the PPO class itself.
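A hedged sketch of what such a from_config typically does, assuming ppo_cfg carries the usual PPO hyperparameters (the field names here are illustrative, not necessarily Habitat's exact ones):

# Illustrative sketch only; Habitat's actual PPO.from_config may differ.
class PPO:
    def __init__(self, actor_critic, clip_param, ppo_epoch, lr):
        self.actor_critic = actor_critic
        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch
        self.lr = lr

    @classmethod
    def from_config(cls, actor_critic, ppo_cfg):
        # Reads hyperparameters from the config and returns a PPO *instance*.
        return cls(
            actor_critic=actor_critic,
            clip_param=ppo_cfg.clip_param,
            ppo_epoch=ppo_cfg.ppo_epoch,
            lr=ppo_cfg.lr,
        )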
The train() method of ......./habitat_baselines/rl/ppo/ppo_trainer.py computes the losses via

losses = self._update_agent()

and inside the _update_agent() method,

self.agent.train()

is called, which puts the agent (an nn.Module) into training mode before the update runs.
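The article stops here, so as a reference point, below is a generic schematic of the clipped surrogate objective that a PPO update of this kind optimizes. It is the textbook form, not Habitat's _update_agent:

import torch

def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_param=0.2):
    """Generic PPO clipped surrogate loss (textbook form, not Habitat's code)."""
    ratio = torch.exp(new_log_probs - old_log_probs)  # pi_new / pi_old
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    # Maximize the clipped surrogate -> minimize its negative.
    return -torch.min(surr1, surr2).mean()

# Toy usage with random numbers:
lp_new = torch.randn(8, requires_grad=True)
lp_old = lp_new.detach() + 0.1 * torch.randn(8)
adv = torch.randn(8)
loss = ppo_clipped_loss(lp_new, lp_old, adv)
loss.backward()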