[pytorch] 训练节省显存的技巧

参考资料:

  • https://blog.csdn.net/Wenyuanbo/article/details/119107466

实践如下:

1. AdamW 优化器比 SGD 优化器更耗显存

  • 有时候能够降几GB,但是有时候不怎么降

2. 删除多余无用变量

del 功能是彻底删除一个变量,要再使用必须重新创建,注意 del 删除的是一个变量而不是从内存中删除一个数据,这个数据有可能也被别的变量在引用,实现方法很简单,比如:

  • 残差链接是可以删除的嘛?这个我没有实践过,我怕删了计算不了梯度
 def forward(self, x):
 
	input_ = x
	x = F.relu_(self.conv1(x) + input_)
	x = F.relu_(self.conv2(x) + input_)
	x = F.relu_(self.conv3(x) + input_)
	
	del input_  # 删除变量 input_

	x = self.conv4(x)  # 输出层
	return x

  • 训练完之后删除loss释放计算图,不然loss和下一个batch重叠在一起,很消耗显存,实测可以下降9GB左右显存,可能是我的模型有点大。
total_loss /= self.acc_step
total_loss.backward()
if i % self.acc_step == 0 or i == len(train_loader):  # i starts from 1
    if self.args.use_amp:
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
    else:
        self.optimizer.step()
        self.optimizer.zero_grad()

del x, y, out, total_loss, losses   #  batch 可以删掉,同时 loss也可以删掉

3. 周期性地清空缓存

使用 torch.cuda.empty_cache() 释放显存。

这个我也没有试验过。

4. 清空计算图

loss.backward 是计算梯度, optimzer.step 是更新梯度,但是计算图没有释放,可以使用zero_grad释放计算图。

myNet.zero_grad()  # 模型参数梯度清零
optimizer.zero_grad()  # 优化器参数梯度清零

5. 半精度学习 / 混合精度学习

因为偷懒,使用的是 pytorch 自带的 scaler,发现其实没什么用

if self.args.use_amp:
            self.scaler = GradScaler()

if self.args.use_amp:
   with autocast():
         out = self.model(x)
         total_loss, losses = self.model.loss(out, y)
         total_loss = self.scaler.scale(total_loss)
 else:
     out = self.model(x)
     total_loss, losses = self.model.loss(out, y) 

total_loss /= self.acc_step
 total_loss.backward()
 if i % self.acc_step == 0 or i == len(train_loader):  # i starts from 1
     if self.args.use_amp:
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self.optimizer.zero_grad()
     else:
         self.optimizer.step()
         self.optimizer.zero_grad()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值