按照套路初始化Task→leggedrobot(对应机器人的类,就像b1z1),并由此初始化环境envs,初始化onpolicyrunner
然后learn
get_observations得到4096,48的obs,这是compute_observations得到的,纯电机数据和记忆数据组成
其中:
base_lin_vel 4096,3 基座线速度
base_ang_vel 4096,3 基座角速度
projected_gravity 4096,3 重力向量
commands前三维 4096,3 迭代初始给定随机速度方向
dof_pos-default_dof_pos 4096,12 关节位置偏移量
dof_vel 4096,12 关节速度
actions 4096,12
共1500个iter
每iter内每个环境24step(4096个env)
ActorCritic(
# 得到self.transition.actions
输入obs
(actor): Sequential(
(0): Linear(in_features=48, out_features=512, bias=True)
(1): ELU(alpha=1.0)
(2): Linear(in_features=512, out_features=256, bias=True)
(3): ELU(alpha=1.0)
(4): Linear(in_features=256, out_features=128, bias=True)
(5): ELU(alpha=1.0)
(6): Linear(in_features=128, out_features=12, bias=True)
)
得到4096,12 actions
# 得到self.transition.values
一开始输入的critic_obs就是obs
(critic): Sequential(
(0): Linear(in_features=48, out_features=512, bias=True)
(1): ELU(alpha=1.0)
(2): Linear(in_features=512, out_features=256, bias=True)
(3): ELU(alpha=1.0)
(4): Linear(in_features=256, out_features=128, bias=True)
(5): ELU(alpha=1.0)
(6): Linear(in_features=128, out_features=1, bias=True)
)
得到values 4096,1
)
在这个代码片段中,损失函数(loss
)是在强化学习中的Proximal Policy Optimization(PPO)算法中使用的。这个损失函数由以下三个部分组成:策略损失(Surrogate Loss)、价值函数损失(Value Function Loss) 和 熵损失(Entropy Loss)。这些部分被组合在一起,以确保策略的更新能够平衡探索和利用,并且避免过度的策略更新。下面是对每个部分的详细分析:
1. Surrogate Loss(代理损失):
-
定义: 代理损失是PPO中最重要的部分,用于限制策略更新的幅度,避免策略在每次更新时发生过大的变化。
-
计算:
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) surrogate = -torch.squeeze(advantages_batch) * ratio surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
ratio
: 这是当前策略与旧策略的概率比率。surrogate
: 原始的损失项,通过advantages_batch
调整比率。surrogate_clipped
: 通过剪切比率(ratio
),避免策略更新幅度过大。surrogate_loss
: 使用原始和剪切后的损失项中的最大值来计算最终的损失。
-
目的: 通过
clip_param
参数,PPO 能够限制策略的更新幅度,从而确保策略不会偏离太远,这有助于保持稳定的训练过程。
2. Value Function Loss(价值函数损失):
-
定义: 价值函数损失用于最小化策略的估计价值(
value_batch
)与目标值(returns_batch
)之间的差异。 -
计算:
if self.use_clipped_value_loss: value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param, self.clip_param) value_losses = (value_batch - returns_batch).pow(2) value_losses_clipped = (value_clipped - returns_batch).pow(2) value_loss = torch.max(value_losses, value_losses_clipped).mean() else: value_loss = (returns_batch - value_batch).pow(2).mean()
value_loss
: 当使用use_clipped_value_loss
时,通过剪切策略更新,限制价值函数的更新幅度,以避免价值函数更新过度。- 如果不使用剪切,直接最小化估计值与目标值的均方误差。
-
目的: 价值函数损失的目标是让策略能够准确地估计未来的回报,从而在策略更新时具备更好的决策依据。
3. Entropy Loss(熵损失):
-
定义: 熵损失鼓励策略保持探索性,即使在训练过程中,策略也不会过早收敛到某个确定的动作上。
-
计算:
entropy_loss = - self.entropy_coef * entropy_batch.mean()
entropy_loss
: 这里的熵损失通过entropy_coef
调整其在总损失中的权重。
-
目的: 熵正则化项有助于增加策略的随机性,防止过拟合,并鼓励策略在早期阶段进行更多的探索。
4. 总损失函数:
- 定义: 总损失函数是上述三个损失项的加权和。
- 计算:
loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
surrogate_loss
: 主要的策略损失,影响策略的更新方向和幅度。value_loss
: 价值函数损失,调整策略对未来回报的估计。entropy_loss
: 熵损失,增加探索性。
- 目的: 通过加权组合这三个部分的损失,PPO 算法能够在策略更新时平衡多个目标,既能提升策略性能,又能保持稳定的训练过程。
总结
- 策略损失(Surrogate Loss) 用于确保策略更新稳定,避免过度调整。
- 价值函数损失(Value Function Loss) 用于保证策略对未来回报的准确估计。
- 熵损失(Entropy Loss) 用于保持策略的探索性,避免陷入局部最优解。
最终的 loss
是这三个损失的组合,代表了在每次更新中我们希望优化的目标。
play
def export_policy_as_jit(actor_critic, path):
if hasattr(actor_critic, 'memory_a'):
# assumes LSTM: TODO add GRU
exporter = PolicyExporterLSTM(actor_critic)
exporter.export(path)
else:
os.makedirs(path, exist_ok=True)
path = os.path.join(path, 'policy_1.pt')
model = copy.deepcopy(actor_critic.actor).to('cpu')
traced_script_module = torch.jit.script(model)
traced_script_module.save(path)
这个函数 export_policy_as_jit
主要用于将给定的策略模型(actor_critic
)导出为 TorchScript 的格式,以便后续在部署或推理过程中加载和使用。
具体的功能分析:
-
LSTM 的处理:
- 如果
actor_critic
对象有memory_a
属性,假设这是一个基于 LSTM 的模型,使用PolicyExporterLSTM
类来导出 LSTM 模型。由于 LSTM 模型可能涉及到状态的管理,所以需要一个特殊的导出过程。 PolicyExporterLSTM
: 这里并未展示这个类的具体实现,但它显然是一个专门用于导出带有 LSTM(可能也会扩展到 GRU)的策略的工具类。它将actor_critic
模型导出到指定的路径。
- 如果
-
普通模型的处理(非 LSTM 模型):
- 如果
actor_critic
没有memory_a
属性,意味着这是一个普通的模型(例如全连接网络或卷积网络),那么直接导出模型为 TorchScript 格式。 - 具体步骤包括:
- 使用
os.makedirs
创建导出路径(如果路径不存在)。 - 将
actor_critic.actor
模型复制一份,移动到 CPU 设备上(假设是在 GPU 上训练,但导出需要在 CPU 上进行)。 - 通过
torch.jit.script
将模型转换为 TorchScript 格式。 - 最终使用
traced_script_module.save(path)
将转换后的模型保存到指定路径。
- 使用
- 如果
torch.jit.script
与 torch.jit.trace
的区别:
- 这里使用了
torch.jit.script
而不是torch.jit.trace
,表明这个模型不仅包含前向计算逻辑,还可能包含条件分支或循环等复杂控制流。torch.jit.script
会解析整个模型的代码,而torch.jit.trace
只记录实际运行时的计算图。
参数说明:
actor_critic
: 一个策略模型,可能是强化学习中的 actor-critic 模型。path
: 目标路径,用于保存导出的策略模型。
关键逻辑:
- 判断是否为 LSTM:通过
hasattr(actor_critic, 'memory_a')
来检查模型是否含有 LSTM 组件,并决定导出逻辑的不同处理方式。 - 普通模型导出:非 LSTM 模型直接使用 TorchScript(
torch.jit.script
)导出,并保存到.pt
文件中。
代码用途:
这个函数用于强化学习模型的导出,特别是为了在部署阶段使用更高效的 TorchScript 格式。TorchScript 模型可以在生产环境中高效运行,因为它是优化过的静态图形式,适合移动设备、嵌入式设备和服务器上进行推理。
class PolicyExporterLSTM(torch.nn.Module):
def __init__(self, actor_critic):
super().__init__()
self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent
self.memory = copy.deepcopy(actor_critic.memory_a.rnn)
self.memory.cpu()
self.register_buffer(f'hidden_state', torch.zeros(self.memory.num_layers, 1, self.memory.hidden_size))
self.register_buffer(f'cell_state', torch.zeros(self.memory.num_layers, 1, self.memory.hidden_size))
def forward(self, x):
out, (h, c) = self.memory(x.unsqueeze(0), (self.hidden_state, self.cell_state))
self.hidden_state[:] = h
self.cell_state[:] = c
return self.actor(out.squeeze(0))
@torch.jit.export
def reset_memory(self):
self.hidden_state[:] = 0.
self.cell_state[:] = 0.
def export(self, path):
os.makedirs(path, exist_ok=True)
path = os.path.join(path, 'policy_lstm_1.pt')
self.to('cpu')
traced_script_module = torch.jit.script(self)
traced_script_module.save(path)
这个 PolicyExporterLSTM
类的作用是将包含 LSTM 结构的策略模型导出为 TorchScript 格式的 .pt
文件,以便在推理或部署过程中高效加载和执行。让我们详细分析这个类的结构和功能:
1. 初始化函数 __init__
:
self.actor
:复制并保存actor_critic
的策略网络(actor
),这是最终要导出的模型的策略部分。self.is_recurrent
:记录模型是否是一个递归网络(即是否含有 LSTM 或其他 RNN 结构)。self.memory
:复制并保存actor_critic.memory_a.rnn
,即模型中用于记忆的 LSTM 组件。并且将其移动到 CPU 上,因为导出时通常使用 CPU。self.hidden_state
和self.cell_state
:使用register_buffer
创建了两个状态缓冲区,用来保存 LSTM 的隐状态和单元状态。缓冲区的尺寸由 LSTM 的层数和隐藏状态的尺寸确定。
2. 前向传播 forward
:
- 输入
x
:假定输入是一个批次的观测值。 - LSTM 计算:输入
x
被转换为三维张量,通过self.memory
(即 LSTM 网络)进行前向传播,并获得新的隐状态h
和单元状态c
。 - 状态更新:LSTM 的隐状态和单元状态会在每次前向传播后更新并存储到缓冲区
self.hidden_state
和self.cell_state
中,保证在下一次调用时可以继续使用这些状态。 - 策略网络计算:LSTM 输出的结果
out
会传递给self.actor
,并生成最终的策略输出。
3. reset_memory
方法:
- 作用:提供一个方法用于重置 LSTM 的隐状态和单元状态。在推理过程中,可以调用这个方法来清空 LSTM 的记忆,使模型从零开始。
- 实现:简单地将
hidden_state
和cell_state
全部置零。
4. export
方法:
- 作用:将整个 LSTM 策略模型导出为 TorchScript 格式的
.pt
文件。 - 实现步骤:
- 创建保存路径,如果路径不存在则新建。
- 将模型移动到 CPU 上。
- 使用
torch.jit.script
将模型转换为 TorchScript 格式。这是因为模型中含有复杂的控制流(如 LSTM 状态更新),需要完整地记录模型的行为。 - 保存为
.pt
文件。
5. TorchScript 支持:
- 通过
torch.jit.export
注解的reset_memory
方法可以直接在导出的 TorchScript 模型中调用。这意味着在推理过程中,你可以通过调用reset_memory
重置 LSTM 的状态。
总结:
PolicyExporterLSTM
是一个用于导出基于 LSTM 的强化学习策略模型的工具类。- 它将包含 LSTM 的策略模型及其隐状态管理(
hidden_state
和cell_state
)封装起来,并允许将这些模型导出为 TorchScript 格式(.pt
文件),以便在推理过程中加载使用。 - 这个导出的模型可以高效地运行,并且能够在推理过程中保持 LSTM 的记忆状态,或通过调用
reset_memory
方法清空这些记忆。
这个类的设计使得在推理或部署阶段,LSTM 的递归状态和策略部分都能高效地管理,并通过 TorchScript 格式提高运行效率。
在推理(playing)阶段而非训练阶段导出模型,并选择使用 TorchScript
代替直接使用 torch.save
,是为了满足推理场景的特殊需求并提升效率。这里是导出 checkpoint 时选择 TorchScript
而非简单 torch.save
的原因:
1. 推理时使用 TorchScript
更高效
- 优化性能:
TorchScript
是 PyTorch 提供的一种将模型转换为静态图的机制,它比原生的 PyTorch 动态计算图(Eager Execution)更高效。在推理时,模型不再需要每次都动态构建计算图,而是直接执行优化后的静态图,这可以显著提高运行速度,尤其是在资源受限的环境中(如移动设备或嵌入式系统)。 - 跨平台执行:导出的
TorchScript
模型可以在没有 PyTorch 环境的地方运行,比如 C++ 程序或其他无需 Python 解释器的环境。这对于部署到生产环境非常有用。
2. 推理的场景需求
- 轻量化部署:在推理过程中,模型不需要像训练时那样频繁地进行梯度计算和更新,模型的参数和结构在推理阶段是固定的。因此,可以导出模型为
TorchScript
格式,它是一种更加轻量级的模型表示形式,适合用于推理或部署。 - 冻结计算图:在推理过程中,通常不需要反向传播和梯度信息,所以使用
TorchScript
导出的模型会去掉训练相关的部分,减少模型的复杂性和资源消耗。
3. torch.save
和 torch.jit.script
的区别
-
torch.save
:- 这是 PyTorch 的标准序列化方法,用于保存模型的状态字典(
state_dict
)或整个模型对象。它保存的是 Python 对象。 - 在保存的模型上,依然需要在推理时使用 PyTorch 动态图框架来重新加载模型,然后通过 Python 环境运行。
- 优点是简单直观,适合在训练过程中保存模型的快照(checkpoint)以供以后继续训练。
- 这是 PyTorch 的标准序列化方法,用于保存模型的状态字典(
-
torch.jit.script
:- 它将 PyTorch 模型(包括模型的控制流和张量计算)转换为 TorchScript 形式,这种形式的模型可以脱离 Python 环境运行。
- 跨平台:TorchScript 可以直接在 C++ 或其他不依赖 Python 的环境中运行,使得模型更容易在生产环境中部署。
- 优化后的性能:TorchScript 模型在推理时具有更高的性能和效率,因为它是静态的、提前编译好的执行图。
4. 为什么不直接用 torch.save
?
- 简化推理流程:在推理时,
TorchScript
模型可以直接加载并运行,不需要依赖 Python 环境进行模型加载、解释和执行。torch.save
保存的模型是 Python 对象的序列化表示,加载时需要完整的 PyTorch 环境。 - 部署需求:
torch.jit.script
导出的模型更适合生产部署,因为它可以在 C++ 环境下直接使用,而torch.save
保存的模型只能在有 Python 和 PyTorch 环境的系统中使用。 - 去除动态依赖:
torch.save
保存的模型保留了所有训练过程中的依赖,而在推理时,这些信息是多余的。通过TorchScript
,可以将模型的计算图静态化,去除多余的依赖,使推理更加高效。
总结:
- 推理时使用
TorchScript
导出模型 可以提供更高的推理效率和跨平台的可移植性,这是在生产环境中非常有用的需求。 torch.save
适合在训练中保存模型的中间状态或用于继续训练的 checkpoint,而在推理和部署阶段,使用TorchScript
导出的模型更加轻量化、跨平台,并且优化了执行性能,适合高效的推理场景。