pytorch retain_graph=True 训练导致GPU显存泄漏 OOM (out of memory)

在PyTorch训练过程中遇到GPU显存不足的问题,即使设置小batch_size也无法解决。问题源于在backward函数中使用了retain_graph=True。此设置不仅降低训练速度,还会增加显存消耗。原因是不同loss未取均值导致累积的tensor过大。解决方案是在计算loss后立即调用.mean()转换为标量。
摘要由CSDN通过智能技术生成

现象:
训练过程中多个loss回传产生了GPU显存不够用的情况(即使是设置batch_size最小也不行),在backward函数中去掉retain_graph=True之后,情况没有出现。

注意:retain_graph=True是一个在训练代码中需要极力避免的一个设置。降低了训练速度,增大了显存开销。但是我的代码中暂时不能避免这个设置。

原因分析:
我这里出现这个情况的原因:因为不同loss求完之后没有算均值,可能返回的是一个tensor,在训练loop中会累积越来越大,要通过 .mean() 把它变成标量。

解决:

criterion = torch.nn.CrossEntropyLoss()
output = module_a(fc1Features
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值