PPO代码解读【gymRL】

这篇文章选取经典月球登陆的小游戏的代码来学习PPO代码的实现

在这里插入图片描述

首先定义一些参数

class Config(BasicConfig):
    def __init__(self):
        super(Config, self).__init__()
        self.env_name = 'LunarLander-v2' 
        self.render_mode = 'rgb_array'
        self.algo_name = 'PPO'
        self.train_eps = 2000
        self.batch_size = 1024
        self.mini_batch = 64
        self.epochs = 10
        self.clip = 0.2 # /epsilon
        self.gamma = 0.995 #折扣因子
        self.dual_clip = 3.0 
        self.val_coef = 0.5 #价值函数损失的系数,用于平衡策略损失和价值函数损失。在PPO中,损失函数通常包含策略损失和价值函数损失,这个系数用于调节两者的相对重要性。
        self.lr_start = 5e-4
        self.lr_end = 5e-5
        self.ent_coef = 1e-2 #熵惩罚系数,用于鼓励策略的探索性。较高的熵系数会鼓励策略更多地探索不同的动作,从而避免陷入局部最优解
        self.grad_clip = 0.5 #梯度剪切的阈值
        self.load_model = True

然后定义动作价值网络

class ActorCritic(nn.Module):
    def __init__(self, cfg):
        super(ActorCritic, self).__init__()
        self.fc_head = PSCN(cfg.n_states, 256)
        self.actor_fc = MLP([256, 64, cfg.n_actions])
        self.critic_fc = MLP([256, 32, 1])

    def forward(self, s):
        x = self.fc_head(s)
        prob = F.softmax(self.actor_fc(x), dim=-1)
        value = self.critic_fc(x)
        return prob, value

MLP定义如下,只需要注意dim_list是一个列表,表示每一层的维度。例如 [input_dim, hidden_dim1, hidden_dim2, …, output_dim]

class MLP(nn.Module):
    def __init__(self,
                 dim_list,
                 activation=nn.PReLU(),
                 last_act=False,
                 use_norm=False,
                 linear=nn.Linear,
                 *args, **kwargs
                 ):
        super(MLP, self).__init__()
        assert dim_list, "Dim list can't be empty!"
        layers = []
        for i in range(len(dim_list) - 1):
            layer = initialize_weights(linear(dim_list[i], dim_list[i + 1], *args, **kwargs))
            layers.append(layer)
            if i < len(dim_list) - 2:
                if use_norm:
                    layers.append(nn.LayerNorm(dim_list[i + 1]))
                layers.append(activation)
        if last_act:
            if use_norm:
                layers.append(nn.LayerNorm(dim_list[-1]))
            layers.append(activation)
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

所以actor是256->64->动作数,critic是256->32->1

定义PPO类

class PPO(ModelLoader):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg
        self.net = torch.jit.script(ActorCritic(cfg).to(cfg.device))
        self.optimizer = optim.Adam(self.net.parameters(), lr=cfg.lr_start, eps=1e-5, amsgrad=True)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=cfg.train_eps, eta_min=cfg.lr_end) #学习率调度器
        self.memory = ReplayBuffer(cfg)
        self.learn_step = 0
        self.scaler = GradScaler()

选择动作函数

    @torch.no_grad()
    def choose_action(self, state):
        state = torch.tensor(state, device=self.cfg.device, dtype=torch.float).unsqueeze(0)
        prob, value = self.net(state)
        dist = Categorical(prob)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), value.item()

评估时选择动作

    @torch.no_grad()
    def evaluate(self, state):
        state = torch.tensor(state, device=self.cfg.device, dtype=torch.float).unsqueeze(0)
        prob, _ = self.net(state)
        m = Categorical(prob)
        action = m.probs.argmax().item() #选择最大的
        return action

重点update!

    def update(self):
        states, actions, old_probs, adv, v_target = self.memory.sample() # 获取经验池中的数据
        losses = np.zeros(5) 

        for _ in range(self.cfg.epochs): 
            for indices in BatchSampler(SubsetRandomSampler(range(self.memory.size())), self.cfg.mini_batch, drop_last=False):# 随机取小批量数据
                with autocast():# 混合精度训练
                    actor_prob, value = self.net(states[indices]) # 获取策略网络的输出
                    log_probs = torch.log(actor_prob.gather(1, actions[indices])) # 计算动作概率分布的对数概率,也是新策略的对数概率
                    ratio = torch.exp(log_probs - old_probs[indices])
                    """
                    这个比值用于计算优势函数(advantage function)的估计,并用于更新策略网络。
                    这样可以确保策略更新不会偏离当前策略太远,从而提高训练的稳定性和效率
                    """
                    
                    surr1 = ratio * adv[indices] # r(θ)*优势函数
                    surr2 = torch.clamp(ratio, 1 - self.cfg.clip, 1 + self.cfg.clip) * adv[indices] #截断

                    min_surr = torch.min(surr1, surr2) #min
                    clip_loss = -torch.mean(torch.where(
                        adv[indices] < 0,
                        torch.max(min_surr, self.cfg.dual_clip * adv[indices]),
                        min_surr
                    ))
                    """
                    计算优势函数(adv)小于0时的最小替代损失(min_surr)与配置项dual_clip乘以优势函数的较大值
                    使用配置项 dual_clip 可以限制损失函数的下界,防止过拟合或梯度异常,使得训练过程更加平滑
                    """
                    value_loss = F.mse_loss(v_target[indices], value)
                    entropy_loss = -torch.mean(-torch.sum(actor_prob * torch.log(actor_prob), dim=1))
                    loss = clip_loss + self.cfg.val_coef * value_loss + self.cfg.ent_coef * entropy_loss

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                nn.utils.clip_grad_norm_(self.net.parameters(), self.cfg.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                losses[0] += loss.item()
                losses[1] += clip_loss.item()
                losses[2] += value_loss.item()
                losses[3] += entropy_loss.item()
                
        self.scheduler.step()
        self.memory.clear()
        self.learn_step += 1

        return {
            'total_loss': losses[0] / self.cfg.epochs,
            'clip_loss': losses[1] / self.cfg.epochs,
            'value_loss': losses[2] / self.cfg.epochs,
            'entropy_loss': losses[3] / self.cfg.epochs / (self.cfg.batch_size // cfg.mini_batch),
            'advantage': adv.mean().item(),
            'lr': self.optimizer.param_groups[0]['lr'],
        }

在这段代码中,有多个损失函数(loss)用于不同的目的。这些损失函数的设计旨在优化强化学习算法(特别是PPO)的性能。以下是每个损失函数及其对应关系的详细解释:

1. Clip Loss (clip_loss)

  • 代码部分:

    surr1 = ratio * adv[indices]
    surr2 = torch.clamp(ratio, 1 - self.cfg.clip, 1 + self.cfg.clip) * adv[indices]
    min_surr = torch.min(surr1, surr2)
    clip_loss = -torch.mean(torch.where(
        adv[indices] < 0,
        torch.max(min_surr, self.cfg.dual_clip * adv[indices]),
        min_surr
    ))
    
  • 解释:

    • clip_loss 是PPO的核心损失函数之一。PPO通过限制策略更新的幅度,防止策略发生剧烈的变化。它通过计算两个损失 surr1surr2 来实现:
      • surr1 是未裁剪的损失,计算方式是策略比率 ratio 与优势函数 adv 的乘积。
      • surr2 是裁剪后的损失,通过 torch.clamp 限制策略比率在 [1 - clip, 1 + clip] 范围内。
    • 然后,通过 torch.min 选择 surr1surr2 的较小值,以确保策略更新不会偏离旧策略太远。
    • torch.where 用于进一步处理当 adv 小于0时,通过 dual_clip 限制损失的下界,这有助于防止过拟合或梯度异常。
    • 最终 clip_loss 是通过取负均值计算得出。

2. Value Loss (value_loss)

  • 代码部分:

    value_loss = F.mse_loss(v_target[indices], value)
    
  • 解释:

    • value_loss 是用于优化价值网络的损失函数。它通过最小化预测的价值 value 与目标价值 v_target 之间的均方误差(MSE)来进行优化。
    • 价值网络旨在估计状态的价值函数,value_loss 的目标是使网络的估计更加准确。

3. Entropy Loss (entropy_loss)

  • 代码部分:

    entropy_loss = -torch.mean(-torch.sum(actor_prob * torch.log(actor_prob), dim=1))
    
  • 解释:

    • entropy_loss 用于鼓励策略的探索性。它通过计算策略的熵(entropy)来实现,熵越大,说明策略越倾向于探索不同的动作,而不是集中在某一个动作上。
    • 在强化学习中,高熵可以避免策略过早收敛到局部最优解。这个损失通常是负值(因为我们希望最大化熵),因此在总损失中它会带来一定的正向贡献(通过 self.cfg.ent_coef 进行加权)。

4. 总损失 (loss)

  • 代码部分:

    loss = clip_loss + self.cfg.val_coef * value_loss + self.cfg.ent_coef * entropy_loss
    
  • 解释:

    • loss 是最终的总损失函数,它由三个部分组成:
      • clip_loss: 控制策略更新的稳定性。
      • value_loss: 确保价值函数的估计准确性。
      • entropy_loss: 鼓励策略的探索性。
    • 这些损失分别乘以相应的系数(self.cfg.val_coefself.cfg.ent_coef),以调整它们在总损失中的权重。

梯度计算和优化步骤:

  • 代码部分:

    self.optimizer.zero_grad()
    self.scaler.scale(loss).backward()
    nn.utils.clip_grad_norm_(self.net.parameters(), self.cfg.grad_clip)
    self.scaler.step(self.optimizer)
    self.scaler.update()
    
    • 解释:
      • 这里,首先将累积的梯度清零,然后计算 loss 的梯度并进行反向传播(backward())。
      • 使用 nn.utils.clip_grad_norm_ 对梯度进行裁剪,防止梯度爆炸。
      • 然后通过 self.optimizer.step() 更新模型参数。
      • self.scaler.update() 是在混合精度训练中用于更新缩放因子(scaling factor)。

损失记录:

  • 代码部分:

    losses[0] += loss.item()
    losses[1] += clip_loss.item()
    losses[2] += value_loss.item()
    losses[3] += entropy_loss.item()
    
    • 解释:
      • 这些行代码将每个损失值累加到相应的 losses 列表中,用于后续的损失统计和分析。losses[0] 对应 losslosses[1] 对应 clip_losslosses[2] 对应 value_losslosses[3] 对应 entropy_loss

学习率调整:

  • 代码部分:

    self.scheduler.step()
    
    • 解释:
      • 通过 self.scheduler.step() 调整学习率,通常用于执行预先定义的学习率调度策略,以便在训练过程中动态调整学习率。

内存清理和步骤更新:

  • 代码部分:

    self.memory.clear()
    self.learn_step += 1
    
    • 解释:
      • 清空存储的经验数据(self.memory.clear())为下一个训练周期做准备,同时增加学习步数计数器(self.learn_step += 1),通常用于控制训练进度或记录学习进程。

主函数:

if __name__ == '__main__':
    BenchMark.train(PPO, Config)

下一篇讲train😇😇

  • 8
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch PPO(Proximal Policy Optimization)是一种常用的强化学习算法,用于训练策略模型。下面是一个简单的PyTorch PPO代码示例: ```python import torch import torch.nn as nn import torch.optim as optim import gym # 定义策略网络 class Policy(nn.Module): def __init__(self, input_dim, output_dim): super(Policy, self).__init__() self.fc = nn.Linear(input_dim, 64) self.actor = nn.Linear(64, output_dim) self.critic = nn.Linear(64, 1) def forward(self, x): x = torch.relu(self.fc(x)) action_probs = torch.softmax(self.actor(x), dim=-1) state_value = self.critic(x) return action_probs, state_value # 定义PPO算法 class PPO: def __init__(self, input_dim, output_dim): self.policy = Policy(input_dim, output_dim) self.optimizer = optim.Adam(self.policy.parameters(), lr=0.001) def select_action(self, state): state = torch.FloatTensor(state) action_probs, _ = self.policy(state) action_dist = torch.distributions.Categorical(action_probs) action = action_dist.sample() return action.item() def update(self, states, actions, log_probs, returns, advantages): states = torch.FloatTensor(states) actions = torch.LongTensor(actions) log_probs = torch.FloatTensor(log_probs) returns = torch.FloatTensor(returns) advantages = torch.FloatTensor(advantages) # 计算策略损失和价值损失 action_probs, state_values = self.policy(states) dist = torch.distributions.Categorical(action_probs) new_log_probs = dist.log_prob(actions) ratio = torch.exp(new_log_probs - log_probs) surr1 = ratio * advantages surr2 = torch.clamp(ratio, 1-0.2, 1+0.2) * advantages policy_loss = -torch.min(surr1, surr2).mean() value_loss = nn.MSELoss()(state_values, returns.unsqueeze(1)) # 更新策略网络 self.optimizer.zero_grad() loss = policy_loss + 0.5 * value_loss loss.backward() self.optimizer.step() # 创建环境和PPO对象 env = gym.make('CartPole-v1') input_dim = env.observation_space.shape output_dim = env.action_space.n ppo = PPO(input_dim, output_dim) # 训练PPO模型 max_episodes = 1000 max_steps = 200 for episode in range(max_episodes): state = env.reset() states, actions, log_probs, rewards = [], [], [], [] for step in range(max_steps): action = ppo.select_action(state) next_state, reward, done, _ = env.step(action) states.append(state) actions.append(action) log_probs.append(torch.log(ppo.policy(torch.FloatTensor(state))[action])) rewards.append(reward) state = next_state if done: break # 计算回报和优势值 returns = [] advantages = [] G = 0 for r in reversed(rewards): G = r + 0.99 * G returns.insert(0, G) returns = torch.tensor(returns) returns = (returns -9) for t in range(len(rewards)): advantages.append(returns[t] - ppo.policy(torch.FloatTensor(states[t]))) advantages = torch.tensor(advantages) # 更新策略网络 ppo.update(states, actions, log_probs, returns, advantages) # 使用训练好的模型进行测试 state = env.reset() total_reward = 0 while True: env.render() action = ppo.select_action(state) state, reward, done, _ = env.step(action) total_reward += reward if done: break print("Total reward:", total_reward) ``` 这个示例代码使用PyTorch实现了一个简单的PPO算法,用于在CartPole-v1环境中训练一个策略模型。代码中包含了策略网络的定义、PPO算法的实现以及训练和测试的过程。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值