1 显存变大的原因
PyTorch采用动态图机制,通过tensor(以前是variable)来构建图,tensor里面包含的梯度信息用于反向传播求导。但不是所有变量都应该包含梯度(毕竟东西多,占“面积”就多),否则就会造成网络越跑,所占显存越大的情况,那怎么办呢?
2 loss.item()和loss.detach()解决问题
先看一段代码,
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_func(outputs, targets)
loss.backward()
optimizer.step()
print('loss:',loss)
print('loss.item():',loss.item())
print('loss.detach():',loss.detach())
train_loss += loss.item() # <----关键
...
输出:
loss: tensor(2.3391, device='cuda:0', grad_fn=<NllLossBackward>)
loss.item(): 2.3391051292419434
loss.detach(): tensor(2.3391, device='cuda:0')
很明显,loss.backward()在上面已经进行过了,下面去计算train_loss的时候就不要再带有梯度信息才合适。故有两种解决方案:
- 使用
loss.detach()
来获取不需要梯度回传的部分。
detach()通过重新声明一个变量,指向原变量的存放位置,但是requires_grad变为False。 - 使用
loss.item()
直接获得对应的python数据类型。
建议: 把除了loss.backward()之外的loss调用都改成loss.item()
3 感谢链接
https://www.zhihu.com/question/67209417/answer/344752405