出现了好几次这个问题,一直不知道该怎么处理,作为小白的我,终于在刚刚发现,居然是因为每个batch没有清理显存,导致一直占用着,主要是因为我没有在每个batch进行
optimizer.zero_grad()
loss.backward()
optimizer.step()
这几步
梯度没有清零:在每个训练批次之前,需要将模型的梯度归零,以避免梯度累积。否则,梯度会累积在计算图中,导致显存无法释放。确保在每个批次开始之前,使用optimizer.zero_grad()清零梯度。
我变成在算是在一个epoch结束后再去进行梯度清零了,所以一下子就爆了