读宇树示例

按照套路初始化Task→leggedrobot(对应机器人的类,就像b1z1),并由此初始化环境envs,初始化onpolicyrunner
然后learn

get_observations得到4096,48的obs,这是compute_observations得到的,纯电机数据和记忆数据组成
其中:

base_lin_vel 			4096,3	基座线速度
base_ang_vel 			4096,3	基座角速度
projected_gravity 		4096,3	重力向量
commands前三维			4096,3	迭代初始给定随机速度方向
dof_pos-default_dof_pos	4096,12	关节位置偏移量
dof_vel					4096,12	关节速度
actions					4096,12

共1500个iter
每iter内每个环境24step(4096个env)

ActorCritic(
  # 得到self.transition.actions
  输入obs
  (actor): Sequential(
    (0): Linear(in_features=48, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=128, out_features=12, bias=True)
  )
  得到4096,12 actions
  
  # 得到self.transition.values
  一开始输入的critic_obs就是obs
  (critic): Sequential(
    (0): Linear(in_features=48, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=128, out_features=1, bias=True)
  )
  得到values 4096,1
)


在这个代码片段中,损失函数(loss)是在强化学习中的Proximal Policy Optimization(PPO)算法中使用的。这个损失函数由以下三个部分组成:策略损失(Surrogate Loss)价值函数损失(Value Function Loss)熵损失(Entropy Loss)。这些部分被组合在一起,以确保策略的更新能够平衡探索和利用,并且避免过度的策略更新。下面是对每个部分的详细分析:

1. Surrogate Loss(代理损失)

  • 定义: 代理损失是PPO中最重要的部分,用于限制策略更新的幅度,避免策略在每次更新时发生过大的变化。

  • 计算:

    ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
    surrogate = -torch.squeeze(advantages_batch) * ratio
    surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param)
    surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
    
    • ratio: 这是当前策略与旧策略的概率比率。
    • surrogate: 原始的损失项,通过 advantages_batch 调整比率。
    • surrogate_clipped: 通过剪切比率(ratio),避免策略更新幅度过大。
    • surrogate_loss: 使用原始和剪切后的损失项中的最大值来计算最终的损失。
  • 目的: 通过 clip_param 参数,PPO 能够限制策略的更新幅度,从而确保策略不会偏离太远,这有助于保持稳定的训练过程。

2. Value Function Loss(价值函数损失)

  • 定义: 价值函数损失用于最小化策略的估计价值(value_batch)与目标值(returns_batch)之间的差异。

  • 计算:

    if self.use_clipped_value_loss:
        value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param, self.clip_param)
        value_losses = (value_batch - returns_batch).pow(2)
        value_losses_clipped = (value_clipped - returns_batch).pow(2)
        value_loss = torch.max(value_losses, value_losses_clipped).mean()
    else:
        value_loss = (returns_batch - value_batch).pow(2).mean()
    
    • value_loss: 当使用 use_clipped_value_loss 时,通过剪切策略更新,限制价值函数的更新幅度,以避免价值函数更新过度。
    • 如果不使用剪切,直接最小化估计值与目标值的均方误差。
  • 目的: 价值函数损失的目标是让策略能够准确地估计未来的回报,从而在策略更新时具备更好的决策依据。

3. Entropy Loss(熵损失)

  • 定义: 熵损失鼓励策略保持探索性,即使在训练过程中,策略也不会过早收敛到某个确定的动作上。

  • 计算:

    entropy_loss = - self.entropy_coef * entropy_batch.mean()
    
    • entropy_loss: 这里的熵损失通过 entropy_coef 调整其在总损失中的权重。
  • 目的: 熵正则化项有助于增加策略的随机性,防止过拟合,并鼓励策略在早期阶段进行更多的探索。

4. 总损失函数

  • 定义: 总损失函数是上述三个损失项的加权和。
  • 计算:
    loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
    
    • surrogate_loss: 主要的策略损失,影响策略的更新方向和幅度。
    • value_loss: 价值函数损失,调整策略对未来回报的估计。
    • entropy_loss: 熵损失,增加探索性。
  • 目的: 通过加权组合这三个部分的损失,PPO 算法能够在策略更新时平衡多个目标,既能提升策略性能,又能保持稳定的训练过程。

总结

  • 策略损失(Surrogate Loss) 用于确保策略更新稳定,避免过度调整。
  • 价值函数损失(Value Function Loss) 用于保证策略对未来回报的准确估计。
  • 熵损失(Entropy Loss) 用于保持策略的探索性,避免陷入局部最优解。

最终的 loss 是这三个损失的组合,代表了在每次更新中我们希望优化的目标。

play

def export_policy_as_jit(actor_critic, path):
    if hasattr(actor_critic, 'memory_a'):
        # assumes LSTM: TODO add GRU
        exporter = PolicyExporterLSTM(actor_critic)
        exporter.export(path)
    else: 
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, 'policy_1.pt')
        model = copy.deepcopy(actor_critic.actor).to('cpu')
        traced_script_module = torch.jit.script(model)
        traced_script_module.save(path)

这个函数 export_policy_as_jit 主要用于将给定的策略模型(actor_critic)导出为 TorchScript 的格式,以便后续在部署或推理过程中加载和使用。

具体的功能分析:

  1. LSTM 的处理

    • 如果 actor_critic 对象有 memory_a 属性,假设这是一个基于 LSTM 的模型,使用 PolicyExporterLSTM 类来导出 LSTM 模型。由于 LSTM 模型可能涉及到状态的管理,所以需要一个特殊的导出过程。
    • PolicyExporterLSTM: 这里并未展示这个类的具体实现,但它显然是一个专门用于导出带有 LSTM(可能也会扩展到 GRU)的策略的工具类。它将 actor_critic 模型导出到指定的路径。
  2. 普通模型的处理(非 LSTM 模型)

    • 如果 actor_critic 没有 memory_a 属性,意味着这是一个普通的模型(例如全连接网络或卷积网络),那么直接导出模型为 TorchScript 格式。
    • 具体步骤包括:
      • 使用 os.makedirs 创建导出路径(如果路径不存在)。
      • actor_critic.actor 模型复制一份,移动到 CPU 设备上(假设是在 GPU 上训练,但导出需要在 CPU 上进行)。
      • 通过 torch.jit.script 将模型转换为 TorchScript 格式。
      • 最终使用 traced_script_module.save(path) 将转换后的模型保存到指定路径。

torch.jit.scripttorch.jit.trace 的区别:

  • 这里使用了 torch.jit.script 而不是 torch.jit.trace,表明这个模型不仅包含前向计算逻辑,还可能包含条件分支或循环等复杂控制流。torch.jit.script 会解析整个模型的代码,而 torch.jit.trace 只记录实际运行时的计算图。

参数说明:

  • actor_critic: 一个策略模型,可能是强化学习中的 actor-critic 模型。
  • path: 目标路径,用于保存导出的策略模型。

关键逻辑:

  • 判断是否为 LSTM:通过 hasattr(actor_critic, 'memory_a') 来检查模型是否含有 LSTM 组件,并决定导出逻辑的不同处理方式。
  • 普通模型导出:非 LSTM 模型直接使用 TorchScript(torch.jit.script)导出,并保存到 .pt 文件中。

代码用途:

这个函数用于强化学习模型的导出,特别是为了在部署阶段使用更高效的 TorchScript 格式。TorchScript 模型可以在生产环境中高效运行,因为它是优化过的静态图形式,适合移动设备、嵌入式设备和服务器上进行推理。

class PolicyExporterLSTM(torch.nn.Module):
    def __init__(self, actor_critic):
        super().__init__()
        self.actor = copy.deepcopy(actor_critic.actor)
        self.is_recurrent = actor_critic.is_recurrent
        self.memory = copy.deepcopy(actor_critic.memory_a.rnn)
        self.memory.cpu()
        self.register_buffer(f'hidden_state', torch.zeros(self.memory.num_layers, 1, self.memory.hidden_size))
        self.register_buffer(f'cell_state', torch.zeros(self.memory.num_layers, 1, self.memory.hidden_size))

    def forward(self, x):
        out, (h, c) = self.memory(x.unsqueeze(0), (self.hidden_state, self.cell_state))
        self.hidden_state[:] = h
        self.cell_state[:] = c
        return self.actor(out.squeeze(0))

    @torch.jit.export
    def reset_memory(self):
        self.hidden_state[:] = 0.
        self.cell_state[:] = 0.
 
    def export(self, path):
        os.makedirs(path, exist_ok=True)
        path = os.path.join(path, 'policy_lstm_1.pt')
        self.to('cpu')
        traced_script_module = torch.jit.script(self)
        traced_script_module.save(path)

这个 PolicyExporterLSTM 类的作用是将包含 LSTM 结构的策略模型导出为 TorchScript 格式的 .pt 文件,以便在推理或部署过程中高效加载和执行。让我们详细分析这个类的结构和功能:

1. 初始化函数 __init__

  • self.actor:复制并保存 actor_critic 的策略网络(actor),这是最终要导出的模型的策略部分。
  • self.is_recurrent:记录模型是否是一个递归网络(即是否含有 LSTM 或其他 RNN 结构)。
  • self.memory:复制并保存 actor_critic.memory_a.rnn,即模型中用于记忆的 LSTM 组件。并且将其移动到 CPU 上,因为导出时通常使用 CPU。
  • self.hidden_stateself.cell_state:使用 register_buffer 创建了两个状态缓冲区,用来保存 LSTM 的隐状态和单元状态。缓冲区的尺寸由 LSTM 的层数和隐藏状态的尺寸确定。

2. 前向传播 forward

  • 输入 x:假定输入是一个批次的观测值。
  • LSTM 计算:输入 x 被转换为三维张量,通过 self.memory(即 LSTM 网络)进行前向传播,并获得新的隐状态 h 和单元状态 c
  • 状态更新:LSTM 的隐状态和单元状态会在每次前向传播后更新并存储到缓冲区 self.hidden_stateself.cell_state 中,保证在下一次调用时可以继续使用这些状态。
  • 策略网络计算:LSTM 输出的结果 out 会传递给 self.actor,并生成最终的策略输出。

3. reset_memory 方法

  • 作用:提供一个方法用于重置 LSTM 的隐状态和单元状态。在推理过程中,可以调用这个方法来清空 LSTM 的记忆,使模型从零开始。
  • 实现:简单地将 hidden_statecell_state 全部置零。

4. export 方法

  • 作用:将整个 LSTM 策略模型导出为 TorchScript 格式的 .pt 文件。
  • 实现步骤
    1. 创建保存路径,如果路径不存在则新建。
    2. 将模型移动到 CPU 上。
    3. 使用 torch.jit.script 将模型转换为 TorchScript 格式。这是因为模型中含有复杂的控制流(如 LSTM 状态更新),需要完整地记录模型的行为。
    4. 保存为 .pt 文件。

5. TorchScript 支持

  • 通过 torch.jit.export 注解的 reset_memory 方法可以直接在导出的 TorchScript 模型中调用。这意味着在推理过程中,你可以通过调用 reset_memory 重置 LSTM 的状态。

总结:

  • PolicyExporterLSTM 是一个用于导出基于 LSTM 的强化学习策略模型的工具类。
  • 它将包含 LSTM 的策略模型及其隐状态管理(hidden_statecell_state)封装起来,并允许将这些模型导出为 TorchScript 格式(.pt 文件),以便在推理过程中加载使用。
  • 这个导出的模型可以高效地运行,并且能够在推理过程中保持 LSTM 的记忆状态,或通过调用 reset_memory 方法清空这些记忆。

这个类的设计使得在推理或部署阶段,LSTM 的递归状态和策略部分都能高效地管理,并通过 TorchScript 格式提高运行效率。
在推理(playing)阶段而非训练阶段导出模型,并选择使用 TorchScript 代替直接使用 torch.save,是为了满足推理场景的特殊需求并提升效率。这里是导出 checkpoint 时选择 TorchScript 而非简单 torch.save 的原因:

1. 推理时使用 TorchScript 更高效

  • 优化性能TorchScript 是 PyTorch 提供的一种将模型转换为静态图的机制,它比原生的 PyTorch 动态计算图(Eager Execution)更高效。在推理时,模型不再需要每次都动态构建计算图,而是直接执行优化后的静态图,这可以显著提高运行速度,尤其是在资源受限的环境中(如移动设备或嵌入式系统)。
  • 跨平台执行:导出的 TorchScript 模型可以在没有 PyTorch 环境的地方运行,比如 C++ 程序或其他无需 Python 解释器的环境。这对于部署到生产环境非常有用。

2. 推理的场景需求

  • 轻量化部署:在推理过程中,模型不需要像训练时那样频繁地进行梯度计算和更新,模型的参数和结构在推理阶段是固定的。因此,可以导出模型为 TorchScript 格式,它是一种更加轻量级的模型表示形式,适合用于推理或部署。
  • 冻结计算图:在推理过程中,通常不需要反向传播和梯度信息,所以使用 TorchScript 导出的模型会去掉训练相关的部分,减少模型的复杂性和资源消耗。

3. torch.savetorch.jit.script 的区别

  • torch.save:

    • 这是 PyTorch 的标准序列化方法,用于保存模型的状态字典(state_dict)或整个模型对象。它保存的是 Python 对象。
    • 在保存的模型上,依然需要在推理时使用 PyTorch 动态图框架来重新加载模型,然后通过 Python 环境运行。
    • 优点是简单直观,适合在训练过程中保存模型的快照(checkpoint)以供以后继续训练。
  • torch.jit.script:

    • 它将 PyTorch 模型(包括模型的控制流和张量计算)转换为 TorchScript 形式,这种形式的模型可以脱离 Python 环境运行。
    • 跨平台:TorchScript 可以直接在 C++ 或其他不依赖 Python 的环境中运行,使得模型更容易在生产环境中部署。
    • 优化后的性能:TorchScript 模型在推理时具有更高的性能和效率,因为它是静态的、提前编译好的执行图。

4. 为什么不直接用 torch.save

  • 简化推理流程:在推理时,TorchScript 模型可以直接加载并运行,不需要依赖 Python 环境进行模型加载、解释和执行。torch.save 保存的模型是 Python 对象的序列化表示,加载时需要完整的 PyTorch 环境。
  • 部署需求torch.jit.script 导出的模型更适合生产部署,因为它可以在 C++ 环境下直接使用,而 torch.save 保存的模型只能在有 Python 和 PyTorch 环境的系统中使用。
  • 去除动态依赖torch.save 保存的模型保留了所有训练过程中的依赖,而在推理时,这些信息是多余的。通过 TorchScript,可以将模型的计算图静态化,去除多余的依赖,使推理更加高效。

总结:

  • 推理时使用 TorchScript 导出模型 可以提供更高的推理效率和跨平台的可移植性,这是在生产环境中非常有用的需求。
  • torch.save 适合在训练中保存模型的中间状态或用于继续训练的 checkpoint,而在推理和部署阶段,使用 TorchScript 导出的模型更加轻量化、跨平台,并且优化了执行性能,适合高效的推理场景。
  • 19
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值