两个Batch参数的理解
最近尝试用Tianshou自定义Policy解决机器人强化学习问题,尝试先用PPO训练一个Baseline,由于采用视觉网络作为PreprocessNet,Actor和Critic的网络参数量很大,8G的显存大概只能使用8个的Batch做反向传播,这个场景下Minibatch的使用就很重要了。
然而实际使用中,我们会发现有两个关于Batchsize的定义,并且在一些模型较大情况下,可能会出现不管如何调整Batch,显存依旧溢出的情况,详情见下文。
从Tianshou的A2C基类中可以看到其入口参数处有参数max_batchsize:
class A2CPolicy(PGPolicy):
...
def __init__(
self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
...
max_batchsize: int = 256,
**kwargs: Any
) -> None:
而Tianshou框架训练时使用的Trainer,同样包含入口参数batch_size:
class OnpolicyTrainer(BaseTrainer):
...
def __init__(
self,
policy: BasePolicy,
train_collector: Collector,
...
batch_size: int,
step_per_collect: Optional[int] = None,
...
):
那么这两个batchsize有何区别?
实际上通过调试我们可以发现,Policy定义的max_batchsize实际上是
控制策略(Policy)在前向推断(生成动作)时采用的Minibatch大小,例如PPO中计算return的方法中,我们需要通过并发的方式进行前向计算,这里如果一次性处理整个Batch中传来的尺寸,显卡可能会吃不消:
def _compute_returns(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch:
v_s, v_s_ = [], []
with torch.no_grad():
for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
v_s.append(self.critic(minibatch.obs))
v_s_.append(self.critic(minibatch.obs_next))
...
而Trainer的batch_size,自然就是反向传播时分割Batch的Minibatch大小:
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
# calculate loss for actor
dist = self(minibatch).dist
log_prob = dist.log_prob(minibatch.act)
...
存在的BUG和处理
前面提到在使用PPO时会遇到在一些模型较大情况下,不管如何调整Batch,显存依旧溢出的情况。这个原因实际上是由于PPO在训练策略时会自动调用其前处理方法:
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch:
if self._recompute_adv:
# buffer input `buffer` and `indices` to be used in `learn()`.
self._buffer, self._indices = buffer, indices
batch = self._compute_returns(batch, buffer, indices)
batch.act = to_torch_as(batch.act, batch.v_s)
with torch.no_grad():
batch.logp_old = self(batch).dist.log_prob(batch.act)
return batch
可以看到,这里是没有分割Minibatch的!
如何解决?我认为如果有跟我一样使用较大规模模型的同学肯定会遇到这个情况,实际上我们会发现在https://github.com/thu-ml/tianshou/pull/1168的pull中已经修复了这个问题,但是为什么我Install的源码依旧错误?实际上Tianshou在多个版本迭代中修复了大量BUG,直接clone tianshou的仓库并本地安装到虚拟环境即可修复,这里可能会出现一些python版本不兼容的问题,可以直接通过Pip解决。