Tianshou强化学习框架学习-关于Tianshou中PPO策略的Batch参数

两个Batch参数的理解

最近尝试用Tianshou自定义Policy解决机器人强化学习问题,尝试先用PPO训练一个Baseline,由于采用视觉网络作为PreprocessNet,Actor和Critic的网络参数量很大,8G的显存大概只能使用8个的Batch做反向传播,这个场景下Minibatch的使用就很重要了。

然而实际使用中,我们会发现有两个关于Batchsize的定义,并且在一些模型较大情况下,可能会出现不管如何调整Batch,显存依旧溢出的情况,详情见下文。

从Tianshou的A2C基类中可以看到其入口参数处有参数max_batchsize:
 

class A2CPolicy(PGPolicy):
...

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
...
        max_batchsize: int = 256,
        **kwargs: Any
    ) -> None:

而Tianshou框架训练时使用的Trainer,同样包含入口参数batch_size:

class OnpolicyTrainer(BaseTrainer):
...

    def __init__(
        self,
        policy: BasePolicy,
        train_collector: Collector,
...
        batch_size: int,
        step_per_collect: Optional[int] = None,
...

    ):

那么这两个batchsize有何区别?
实际上通过调试我们可以发现,Policy定义的max_batchsize实际上是控制策略(Policy)在前向推断(生成动作)时采用的Minibatch大小,例如PPO中计算return的方法中,我们需要通过并发的方式进行前向计算,这里如果一次性处理整个Batch中传来的尺寸,显卡可能会吃不消:
 

 def _compute_returns(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        v_s, v_s_ = [], []
        with torch.no_grad():
            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
                v_s.append(self.critic(minibatch.obs))
                v_s_.append(self.critic(minibatch.obs_next))
...

而Trainer的batch_size,自然就是反向传播时分割Batch的Minibatch大小
 

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        losses, actor_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for minibatch in batch.split(batch_size, merge_last=True):
                # calculate loss for actor
                dist = self(minibatch).dist
                log_prob = dist.log_prob(minibatch.act)
...

存在的BUG和处理

前面提到在使用PPO时会遇到在一些模型较大情况下,不管如何调整Batch,显存依旧溢出的情况。这个原因实际上是由于PPO在训练策略时会自动调用其前处理方法:
 

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        if self._recompute_adv:
            # buffer input `buffer` and `indices` to be used in `learn()`.
            self._buffer, self._indices = buffer, indices
        batch = self._compute_returns(batch, buffer, indices)
        batch.act = to_torch_as(batch.act, batch.v_s)
        with torch.no_grad():
            batch.logp_old = self(batch).dist.log_prob(batch.act)
        return batch

可以看到,这里是没有分割Minibatch的!

如何解决?我认为如果有跟我一样使用较大规模模型的同学肯定会遇到这个情况,实际上我们会发现在https://github.com/thu-ml/tianshou/pull/1168的pull中已经修复了这个问题,但是为什么我Install的源码依旧错误?实际上Tianshou在多个版本迭代中修复了大量BUG,直接clone tianshou的仓库并本地安装到虚拟环境即可修复,这里可能会出现一些python版本不兼容的问题,可以直接通过Pip解决。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值