最近在训练网络时发现网络训练了几个epoch之后就会出现OOM
一开始以为是内存不够,后来才发现是在网络训练过程中,显存会不断的增加。
针对以上的问题,查找资料总结了三种有用的方式
- 训练过程过程中,保存参数加.item()
原代码:
def train_one_epoch(
model, criterion, train_dataloader, optimizer, epoch, clip_max_norm
):
model.train()
device = next(model.parameters()).device
train_loss = 0
for i, d in enumerate(train_dataloader):
d = d.to(device)
optimizer.zero_grad()
out_net = model(d)
loss = criterion(out_net, d, epoch)
train_loss += loss
loss.backward()
if clip_max_norm > 0: