【强化学习】stable baselines3中自定义网络的使用(针对Actor-Critic架构)

前言

sb3官方的教程已经给出了自定义网络的使用方式,但是其中给出的例子较为简单,在实际工程中按照官方的例子出现了参数数量错误等各种问题。这里将官方文档进行细化,本文针对Actor-Critic框架下的on-policy算法,如果你看过了官方文档之后仍然不会使用自定义网络,欢迎继续阅读。

源码分析

为了更好了解如何在sb3中使用自定义网络,我们先对源码进行分析。
首先,AC网络基础类ActorCriticPolicy的初始化过程梳理如下(节选重要的部分展示):

# stable_baselines3.common.policies.ActorCriticPolicy
class ActorCriticPolicy(BasePolicy):
	def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
    ):
    	...
    	self.features_extractor = self.make_features_extractor()
        self.features_dim = self.features_extractor.features_dim
        if self.share_features_extractor:
            self.pi_features_extractor = self.features_extractor
            self.vf_features_extractor = self.features_extractor
        else:
            self.pi_features_extractor = self.features_extractor
            self.vf_features_extractor = self.make_features_extractor()
            
            self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)

        self._build(lr_schedule)

AC网络分为两部分:特征提取网络和之后的动作和值网络,在sb3中,将特征网络记为feature_extractor,将动作和值网络统一记为mlp_extractor。以下对这两部分分别进行分析:

  1. 在参数部分,features_extractor_classfeatures_extractor_kwargs控制特征提取网络,这两个参数的使用在self.make_features_extractor中:
    def make_features_extractor(self) -> BaseFeaturesExtractor:
        """Helper method to create a features extractor."""
        return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
    
    分析这个函数得知,features_extractor_kwargs中不需要包含observation_space
  2. 动作网络和值网络的初始化放在self._build()中完成(仍然截取重要部分):
    def _build(self, lr_schedule: Schedule) -> None:
    	self._build_mlp_extractor()
    
        latent_dim_pi = self.mlp_extractor.latent_dim_pi
        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
    
    注意这里在原先的动作和值网络后面加上了一层,这是为了确保正确的输出大小。例如我们对值网络节点的定义是[128,64],sb3会在最后自动加入一层[64,1]来确保输出大小是(batch, 1)。我们聚焦self._build_mlp_extractor()
    def _build_mlp_extractor(self) -> None:
    	self.mlp_extractor = MlpExtractor(
            self.features_dim,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )
    
    可以观察到这里是用到了net_arch参数来创建动作和值网络,我们可以在继承ActorCriticPolicy的子类中重定义这个函数,以实现自定义的动作和值网络。这里可以参考官方给出的示例

使用参数进行自定义类的传递

policy_kwargs参数进行自定义网络传递较为简单,唯一需要注意的是以上分析的,特征提取网络不需要额外传递observation_space,以下给出一个简单的示例:

class MLP(BaseFeaturesExtractor):
    def __init__(
        self,
        observation_space,
        features_dim,
        mlp1_dims,
    ):
        super().__init__(observation_space, features_dim)
        self.mlp1 = mlp(input_dim, mlp1_dims, last_relu=True)

    def forward(self, obs: torch.Tensor):
        features = self.mlp1(obs)
        return features
        
if __name__ == "__main__":
	features_dim = 256
	policy_kwargs = dict(
	    features_extractor_class=MLP,
	    features_extractor_kwargs=dict(features_dim=features_dim, mlp1_dims=[512, features_dim]),
	    activation_fn=nn.ReLU,
	    net_arch=dict(pi=[128, 32], vf=[128, 32]),
	)
	model = PPO(
	    policy='MlpPolicy',
	    env=...,
	    policy_kwargs=policy_kwargs, 
	)

注意在net_arch部分,值网络和动作网络会根据feature_extractor即特征提取网络改变接受的输入的维数,所以不需要使用[features_dim,128,32],直接写[128,32]即可,前者会额外增加一个(in_features=features_dim, out_features=features_dim)的层。

继承Actor-Critic网络

在实际工程中,自定义的网络可以较为复杂,例如动作网络和值网络可以共享某些层,在某些层的位置再分开成为两个网络,这种更加精细的控制,就需要我们继承ActoriCriticPolicy来实现自定义网络来实现。我们分为三部分实现:

  1. 特征提取网络FeatureExtractor
  2. 动作和值网络CustomAC
  3. 整合的完整网络CustomNetwork

首先,特征提取网络部分只需注意在初始化参数表中加入observation_spacefeatures_dim以传递给父类:

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class FeatureExtractor(BaseFeaturesExtractor):
	def __init__(
        self,
        observation_space,
        features_dim,
        ... # 其他参数
    ):
        super().__init__(observation_space, features_dim)
        # 正常做类初始化

其次,对于CustomAC的设计,参照官方文档,注意实现forward_actorforward_critic即可:

class CustomAC(nn.Module):
    def __init__(
        self,
        features_dim: int,
        last_layer_dim_pi: int = 32,
        last_layer_dim_vf: int = 32,
    ):
        super().__init__()

        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        hidden_units_pi_1 = 64
        self.policy_net = nn.Sequential(
            nn.Linear(features_dim, hidden_units_pi_1),
            nn.ReLU(),
            nn.Linear(hidden_units_pi_1, last_layer_dim_pi),
            nn.ReLU(),
        )
        # Value network
        hidden_units_vf_1 = 64
        self.value_net = nn.Sequential(
            nn.Linear(features_dim, hidden_units_vf_1),
            nn.ReLU(),
            nn.Linear(hidden_units_vf_1, last_layer_dim_vf),
            nn.ReLU(),
        )

    def forward(self, features: th.Tensor):
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)

只需注意sb3会自动在值网络和动作网络最后加入一层,确保输出的大小,所以我们初始化self.value_netself.policy_net时,只需写到倒数第二层。
最后,整合两个网络:

from stable_baselines3.common.policies import ActorCriticPolicy

class CustomNetwork(ActorCriticPolicy):
    def __init__(self, *args, **kwargs):
        features_dim = 100
        super().__init__(
            *args,
            **kwargs, 
            features_extractor_class=FeatureExtractor, 
            features_extractor_kwargs=dict(
                features_dim=features_dim,
                ... # 其他FeatureExtractor的参数
            )
        )

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomAC(
            self.features_dim,
            last_layer_dim_pi=32,
            last_layer_dim_vf=32,
        )

几个需要注意的点:

  1. _build_mlp_extractor中使用的self.features_dim来自于self.feature_extractor.features_dim,因为特征提取网络先初始化,这也提醒我们,在初始化的时候,尽量不要在super().__init__()后单独初始化self.feature_extractor,而是使用父类初始化提供的feature_extractor参数。
  2. _build_mlp_extractor是在super().__init__()中调用的
  3. 使用的时候,只需model=PPO(CustomNetwork)即可,不需要再传递policy_kwargs

总结

本文说明了stable baselines3中on-policy Actor-Critic网络自定义网络的传递方式,官方文档中提到:

针对off-policy的AC算法,需要深入了解所用的算法,在此基础上使用自定义网络

这是由于SAC等算法,使用了Double-Q网络,对于特征提取网络的共享、值函数和动作函数的前向传播的控制会更加复杂。之后的系列文章将对在几个经典的off-policy算法框架下使用自定义网络进行说明

### Stable-Baselines3 的目录结构及各部分作用 Stable-Baselines3 是一个用于快速实现强化学习算法的 Python 库,基于 PyTorch 构建。以下是其主要目录结构以及各个文件和模块的功能说明: #### 1. `stable_baselines3` 主目录 这是核心代码所在的目录,包含了所有的算法实现和其他支持工具。 - **`common/`**: 这是一个通用组件集合,提供了多个算法共享的基础功能。其中包括环境包装器、策略网络定义以及其他辅助函数。 - `base_class.py`: 抽象基类,定义了所有 RL 算法应遵循的标准接口[^2]。 - `buffers.py`: 实现经验回放缓冲区(Replay Buffer),存储 agent 和 environment 的交互数据以便后续训练。 - `policies.py`: 定义不同类型的神经网络架构作为策略模型的一部分。 - `vec_env/`: 提供向量化环境的支持,允许并行运行多个环境实例来加速采样过程。 - **`algos/` 或具体算法文件 (e.g., `ppo.py`, `dqn.py`)** 每种具体的强化学习方法都有对应的单独文件。这些文件实现了各自的逻辑流程,并继承自 common/base_class 中定义的一般框架。 - PPO: Proximal Policy Optimization, 高效更新策略参数的同时保持稳定性。 - DQN: Deep Q-Networks, 使用深度神经网络近似动作值函数Q(s,a). #### 2. 测试用例 (`tests/`) 该子目录保存着单元测试脚本,用来验证各项功能是否按预期工作。它对于开发者维护高质量代码至关重要。 #### 3. 文档资料 (`docs/`) 文档描述了 API 接口详情、安装指南、教程等内容,帮助用户更好地理解和使用此库。 ```python from stable_baselines3 import PPO model = PPO('MlpPolicy', 'CartPole-v1').learn(total_timesteps=10_000) ``` 上述例子展示了如何利用预设好的 MlpPolicy 来解决 CartPole 游戏问题并通过 learn 方法启动训练循环[^1].
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值