关于模型训练中显存占用过大的或直接报显存爆炸的解决方法

本文讲述了如何在模型训练中避免验证阶段的显存爆炸问题,重点在于理解梯度、反向传播和显存管理。作者通过调整代码,发现loss.data.item()的使用显著减少了显存占用,并揭示了验证代码中未采取类似优化措施导致的问题。最终,通过在验证阶段使用no_grad()来避免梯度计算,成功解决了显存问题。
摘要由CSDN通过智能技术生成

模型训练显存爆炸解决方法

在模型训练中,应该理解梯度、反向传播、图层、显存这些概念,在模型训练过程中,一般会分为训练+验证+测试 ,在这些过程中,一般在训练过程中会比较占用显存,因为涉及到反向传播,需要大量的梯度,这些数据又存放在显存中。
在今天模型的训练中,突然发现可以训练,但是在验证过程中出现显存爆炸炸,提示我显存不足,我就很纳闷,一直在找问题,终于发现了:

在我的训练代码中:

   for epoch in range(0, epoch_num):
        net.train()
        for i, data in enumerate(train_dataloader):
            ite_num = ite_num + 1
            inputs, labels = data['image'], data['maskl']
            inputs = inputs.type(torch.FloatTensor)#注意第1行
            labels = labels.type(torch.FloatTensor)#注意第2行

            # wrap them in Variable:
            if torch.cuda.is_available():
                input, label = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)#注意第3行
            else:
                input, label = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)#注意第4行

            optimizer.zero_grad() 
            pred = net(input)
            loss = dice_bce_loss_fusion(pred,label)
            loss.backward() 
            optimizer.step() 
            running_loss += loss.data.item()#注意第5行
            del d0,loss 

**分宜以上代码,我通过print(loss.requires_grad)发现loss是有梯度的,但是我在累加的时候用了loss.data.item(),这样就减小了显存的占用量,且使用了del loss,但是我的验证代码中,没有使用loss.data.item()**和d0,导致在验证时出现显存爆炸。
在模型训练过程中,应该知道训练和验证的交替:

# evaluate model:
model.eval()#切换到验证
with torch.no_grad():#注意这一行,使得在内的所有参数均没有梯度,加快模型的训练与验证
    ...
    out_data = model(data)#数据是没有梯度的

# training step
model.train()#恢复模型训练
    ...

以上就是解决在验证时报出显存爆炸的解决方法,特此记录一波。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

孤鸟的歌

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值