DI-engine强化学习入门(七)如何自定义神经网络模型

在强化学习中,需要根据决策问题和策略选择合适的神经网络。DI-engine中,神经网络模型可以通过两种方式指定:

  1. 使用配置文件中的cfg.policy.model自动生成默认模型。这种方式下,可以在配置文件中指定神经网络的类型(MLP、CNN等)以及超参数(隐层大小、激活函数等),DI-engine会根据这些配置自动构建神经网络模型。这种方式简单易用,适用于常见的标准网络结构。
  2. 自定义模型实例并传入Policy。这种方式下,需要用户自己定义Tensorflow/Pytorch模型类,实现前向传播等接口,然后将实例传入Policy中。这种方式灵活度高,用户可以自由设计任意结构的神经网络。但是需要用户比较熟悉网络定义和 Tensorflow/Pytorch接口。

(注:在强化学习中,策略(Policy)是指智能体(Agent)决策的规则。策略是从状态(State)到动作(Action)的映射,它定义了在给定的状态下,智能体应该采取什么动作。策略可以是确定性的(Deterministic)也可以是随机性的(Stochastic)。)
以上两种方式都会在Policy中封装为neural_net属性,策略学习会通过这个网络完成状态的 embedding 以及动作的选择。这套机制和接口为用户提供了必要的灵活性,可以根据具体问题和需求配置各种神经网络模型。
这是红色的文字

Policy 默认使用的模型是什么
DI-engine 中已经实现的 policy,默认使用 default_model 方法中表明的神经网络模型,例如在 SACPolicy 中:

@POLICY_REGISTRY.register('sac')class SACPolicy(Policy):...    def default_model(self) -> Tuple[str, List[str]]:        if self._cfg.multi_agent:            return 'maqac_continuous', ['ding.model.template.maqac']        else:            return 'qac', ['ding.model.template.qac']...

在这段代码中,我们看到的是一个名为 DI-engine 的强化学习框架中的一个策略(Policy)类的一部分。具体来说,它定义了一个使用Soft Actor-Critic, SAC 算法的策略类。这个段落描述了如何在这个框架内设置和使用策略相关的神经网络模型。

让我们逐步解释这段代码:

  1. @POLICY_REGISTRY.register('sac') 是一个装饰器,它将 SACPolicy 类注册到一个名为 POLICY_REGISTRY 的注册器中,并且用 'sac' 作为这个策略的标识符。这样的注册机制允许框架能够根据名字轻松地查找和实例化策略。
  2. class SACPolicy(Policy): 表明 SACPolicy 是从更一般的 Policy 类派生的,它是一个具体的策略实现,使用了 SAC 算法。
  3. default_model 是 SACPolicy 类的一个方法,它定义了该策略默认使用的模型。这个方法返回两个元素的元组:
  • 'maqac_continuous' 或 'qac':这是在模型注册器中注册的模型名字。根据配置是否是多智能体(multi_agent),它返回不同的模型名。
  • ['ding.model.template.maqac'] 或 ['ding.model.template.qac']:这是包含模型类的文件路径的列表。这个路径告诉 DI-engine 在哪里可以找到定义模型的代码。

4.当使用配置文件时,DI-engine 的入口文件将使用 cfg.policy.model 中的参数(比如 obs_shape, action_shape)来实例化提供的模型类。这个过程是自动化的,意味着用户定义好配置,DI-engine 将负责根据这些配置创建并初始化模型。

5.模型类会根据传入的参数构造适当的神经网络。例如,如果传入的 obs_shape 参数表明观测是一个图像,则模型可能会使用卷积层来处理输入;如果观测是一个向量,则可能使用全连接层。

简而言之,这段代码展示了 DI-engine 如何灵活地处理不同类型的策略和模型,以及如何通过配置文件来方便地自定义和实例化这些策略和模型。这种设计允许研究者和开发者能够轻松试验不同的算法和模型架构,而无需直接修改代码。

如何自定义神经网络模型

在 DI-engine 强化学习框架中,每个策略(如 SACPolicy)通常有一个关联的默认模型(通过 default_model 方法指定),这个默认模型是为特定类型的任务设计的。例如,原始的 qac 模型可能是为处理具有一维观测空间的环境设计的,即观测是一个向量。

但是!如果任务是在一个模拟器(如 dmc2gym,一个DeepMind Control Suite到OpenAI Gym接口的适配器)上运行,并且任务是 cartpole-swingup,而且你希望使用观测为像素的输入(即观测是一个图像),那么默认的 qac 模型不足以处理这样的高维度和多通道的输入。在这种情况下,观测空间的形状是 (3, height, width),其中 3 表示图像的颜色通道数(RGB),height 和 width 分别表示图像的高度和宽度。

在 dmc2gym 文档中,from_pixel 参数设定为 True 意味着环境将提供像素级的观测,而 channels_first 设定为 True 表明图像的通道维度是第一维(这是PyTorch等深度学习库通常采用的格式)。

面对这样的情况,如果你想要使用 SAC 算法处理像素级的观测,你需要自定义一个能够处理这种高维观测的模型。所以我们创建一个新的模型类,该类在内部使用卷积神经网络(CNN)来处理输入的图像数据,并适当地修改网络架构以适应任务的特定要求。

自定义模型完成后,可以将这个模型应用到 SACPolicy 中,替换原本的 qac 模型。涉及到以下几个步骤:

  1. 实现一个新的模型类,它继承自某个基础模型类,并覆盖必要的方法以支持像素级输入。
  2. 在策略配置中指定你的自定义模型,以便 DI-engine 使用你提供的模型而不是默认模型。
  3. 确保你的自定义模型注册到 DI-engine 的模型注册器中,这样框架可以识别和使用它。

自定义 model 基本步骤
1. 明确环境 (env) 和策略 (policy)
首先,需要确定你的强化学习任务的具体环境和任务。例如,我们选择 dmc2gym 环境中的 cartpole-swingup 任务,并且决定观测将以像素数据的形式提供,我们的观测空间是一个图像,其形状为 (3, height, width)。下面我们使用 SAC 算法来进行学习。

在这里,from_pixel = True 表明环境将提供基于像素的观测,channels_first = True 表明图像数据的通道维度在前,这通常是深度学习库(如 PyTorch)的标准格式。

2. 查阅策略中的 default_model 是否适用
接下来,需要检查选择的策略是否具有适用于任务的默认模型。这可以通过查看策略的文档或直接查阅源代码来完成。以 DI-engine 中的 SAC 策略为例,可以查看 SACPolicy 类中的 default_model 方法来了解默认模型:

如果进一步看一下 ding.model.template.qac 中的 QAC 模型,咱们可能会发现它仅支持一维的观测空间,而不支持像 (3, height, width) 这样的图像形状。这意味着对于我们的 cartpole-swingup 任务,需要创建一个自定义模型来处理像素级的观测。

3. 自定义模型 (custom_model) 实现
自定义模型的实现需要遵循一些基本的原则,以确保与 DI-engine 框架的兼容性。

a. 实现功能
自定义模型需要实现默认模型中的所有公共方法。包括:

  • __init__: 构造函数,对模型的各个部分进行初始化。
  • forward: 定义模型如何从输入到输出的前向传递。
  • compute_actor: 计算策略网络的输出,即给定观测值时的动作。
  • compute_critic: 计算价值网络的输出,即动作的价值。

b. 保持返回值类型一致
自定义模型的方法需要保证与原始默认模型的返回值类型一致,以便于替换使用。

c. 利用已实现的 encoder 和 head
在 ding/model/common 下有多种 encoder 和 head 的实现,可用于构建不同部分的模型:

  • Encoder: 负责对输入数据进行编码,使其适合后续的处理。例如,ConvEncoder 用于处理图像观测输入,FCEncoder 用于处理一维观测输入。

点击DI-engine强化学习入门(七)如何自定义神经网络模型 - 古月居可查看全文

  • 7
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值