Pytorch堆叠多个损失造成内存爆炸

这几天跑代码的时候,跑着跑着就显示被killed掉(整个人都不好了)。查系统日志发现是内存不够(out of memory),没……办……法……了,直接放弃!当然这是不可能的,笔者怎么可能是个轻言放弃的人呢,哈哈。

言归正传,笔者用的设备是3T的硬盘,跑的程序batch_size=1024,共划分了2000个batch,每跑一个batch内存占用率就会升高,0.5%左右,无奈之下只能一句一句debug,最后后发现是损失累加造成的,如下放所示,代码共计算了三个损失:BPR Loss , Reg Loss , InfoNCE Loss,不能直接累加作为total_loss!而是通过.item()将损失值取出,再累加。

# BPR Loss
bpr_loss = -torch.sum(F.logsigmoid(sup_logits)) 

# Reg Loss
reg_loss = l2_loss(
self.lightgcn.user_embeddings(bat_users),
self.lightgcn.item_embeddings(bat_pos_items),
self.lightgcn.item_embeddings(bat_neg_items),
)
                
# InfoNCE Loss
clogits_user = torch.logsumexp(ssl_logits_user / self.ssl_temp, dim=1)
clogits_item = torch.logsumexp(ssl_logits_item / self.ssl_temp, dim=1)
infonce_loss = torch.sum(clogits_user + clogits_item)
    
loss = bpr_loss + self.ssl_reg * infonce_loss + self.reg * reg_loss

total_loss = total_loss + loss.item() 
total_bpr_loss += bpr_loss.item()
total_reg_loss += self.reg * reg_loss.item()
               
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值