Puzzled...

I suspected that one loss term (loss_p) was causing the collapse (the CAM going all zeros), so I tried dropping it from the total loss. It turned out that the collapse happens as soon as loss_p is computed (even when it takes no part in backward), and does not happen when the computation is skipped. Changing the first code block below into the second fixed it; pay particular attention to cam_little, including where it enters the loss_p computation.
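
In sketch form, the ablation looked like this (a minimal toy stand-in, not the real network; only the need_p flag matches the code below). In the real model, merely taking the need_p=True branch was enough to trigger the collapse, even though loss_p never entered the backward pass:

    import torch
    import torch.nn as nn

    net = nn.Conv2d(3, 4, 1)        # toy stand-in, not the real network
    x = torch.randn(2, 3, 8, 8)

    def total_loss(need_p):
        out = net(x)
        loss_cls = out.mean()
        if need_p:
            # loss_p is computed here but deliberately excluded from the
            # returned sum, so it should not influence the gradients at all
            loss_p = (out.detach() ** 2).mean()
        return loss_cls

    total_loss(need_p=True).backward()   # the branch that crashed in the real model
    net.zero_grad()
    total_loss(need_p=False).backward()  # the branch that ran fine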

Not OK:

    def forward(self, x, label, need_inter=False, need_p=True):
        # print("1=", meminfo.free / 1024**3)  # free GPU memory (GiB)
        scale_factor = 0.3
        N, C, H, W = x.size()
        d = super().forward_as_dict(x)
        feature = self.dropout7(d['conv6'])
        cam_little = self.fc8(feature)  # raw CAM; still part of the autograd graph
        cam = cam_little.clone()        # clone() keeps the graph attached
        n, c, h, w = cam.size()
        with torch.no_grad():
            # per-class max-normalized CAM; channel 0 becomes the background score
            cam_d = F.relu(cam.detach())
            cam_d_max = torch.max(cam_d.view(n, c, -1), dim=-1)[0].view(n, c, 1, 1) + 1e-5
            cam_d_norm = F.relu(cam_d - 1e-5) / cam_d_max
            cam_d_norm[:, 0, :, :] = 1 - torch.max(cam_d_norm[:, 1:, :, :], dim=1)[0]
            # keep only the per-pixel argmax among the foreground channels
            cam_max = torch.max(cam_d_norm[:, 1:, :, :], dim=1, keepdim=True)[0]
            cam_d_norm[:, 1:, :, :][cam_d_norm[:, 1:, :, :] < cam_max] = 0

        f8_3 = F.relu(self.f8_3(d['conv4'].detach()), inplace=True)
        f8_4 = F.relu(self.f8_4(d['conv5'].detach()), inplace=True)
        x_s = F.interpolate(input=x, size=(h, w), mode='bilinear', align_corners=True)
        f = torch.cat([x_s, f8_3, f8_4], dim=1)
        n, c, h, w = f.size()
        cam = F.interpolate(input=cam, size=(H, W), mode='bilinear', align_corners=True)
        cam_rv = F.interpolate(input=self.PCM(cam_d_norm, f), size=(H, W),
                               mode='bilinear', align_corners=True)
        # cam_rv = cam
        if need_p:
            # foreground channels of cam_little, detached before entering pNet
            loss_p = self.pNet(feature, cam_little[:, 1:, ...].detach(), label[:, 1:, ...])
        else:
            loss_p = torch.tensor(0., dtype=torch.float64, device=x.device)
        label1 = F.adaptive_avg_pool2d(cam_little, (1, 1))  # cam_little carries grad into fc8 here
        loss_rvmin = adaptive_min_pooling_loss((cam_rv * label)[:, 1:, :, :])
        if need_inter:
            cam = F.interpolate(input=visualization.max_norm(cam), scale_factor=scale_factor,
                                mode='bilinear', align_corners=True,
                                recompute_scale_factor=True) * label
            cam_rv = F.interpolate(input=visualization.max_norm(cam_rv), scale_factor=scale_factor,
                                   mode='bilinear', align_corners=True,
                                   recompute_scale_factor=True) * label
        else:
            cam = visualization.max_norm(cam) * label
            cam_rv = visualization.max_norm(cam_rv) * label
        loss_cls = F.multilabel_soft_margin_loss(label1[:, 1:, :, :], label[:, 1:, :, :])
        return cam, cam_rv, loss_cls, loss_rvmin, loss_p

OK (the only changes are where cam_little is defined and how it enters loss_p):

    def forward(self, x, label, need_inter=False, need_p=True):
        # print("1=", meminfo.free / 1024**3)  # free GPU memory (GiB)
        scale_factor = 0.3
        N, C, H, W = x.size()
        d = super().forward_as_dict(x)
        feature = self.dropout7(d['conv6'])
        cam = self.fc8(feature)            # raw CAM; stays in the autograd graph
        cam_little = cam.clone().detach()  # detached copy, cut off from the graph
        n, c, h, w = cam.size()
        with torch.no_grad():
            # per-class max-normalized CAM; channel 0 becomes the background score
            cam_d = F.relu(cam.detach())
            cam_d_max = torch.max(cam_d.view(n, c, -1), dim=-1)[0].view(n, c, 1, 1) + 1e-5
            cam_d_norm = F.relu(cam_d - 1e-5) / cam_d_max
            cam_d_norm[:, 0, :, :] = 1 - torch.max(cam_d_norm[:, 1:, :, :], dim=1)[0]
            # keep only the per-pixel argmax among the foreground channels
            cam_max = torch.max(cam_d_norm[:, 1:, :, :], dim=1, keepdim=True)[0]
            cam_d_norm[:, 1:, :, :][cam_d_norm[:, 1:, :, :] < cam_max] = 0

        f8_3 = F.relu(self.f8_3(d['conv4'].detach()), inplace=True)
        f8_4 = F.relu(self.f8_4(d['conv5'].detach()), inplace=True)
        x_s = F.interpolate(input=x, size=(h, w), mode='bilinear', align_corners=True)
        f = torch.cat([x_s, f8_3, f8_4], dim=1)
        n, c, h, w = f.size()
        cam = F.interpolate(input=cam, size=(H, W), mode='bilinear', align_corners=True)
        cam_rv = F.interpolate(input=self.PCM(cam_d_norm, f), size=(H, W),
                               mode='bilinear', align_corners=True)
        # cam_rv = cam
        if need_p:
            # cam_little is already detached, so no extra .detach() is needed here
            loss_p = self.pNet(feature, cam_little[:, 1:, ...], label[:, 1:, ...])
        else:
            loss_p = torch.tensor(0., dtype=torch.float64, device=x.device)
        label1 = F.adaptive_avg_pool2d(cam_little, (1, 1))  # detached here, unlike the first version
        loss_rvmin = adaptive_min_pooling_loss((cam_rv * label)[:, 1:, :, :])
        if need_inter:
            cam = F.interpolate(input=visualization.max_norm(cam), scale_factor=scale_factor,
                                mode='bilinear', align_corners=True,
                                recompute_scale_factor=True) * label
            cam_rv = F.interpolate(input=visualization.max_norm(cam_rv), scale_factor=scale_factor,
                                   mode='bilinear', align_corners=True,
                                   recompute_scale_factor=True) * label
        else:
            cam = visualization.max_norm(cam) * label
            cam_rv = visualization.max_norm(cam_rv) * label
        loss_cls = F.multilabel_soft_margin_loss(label1[:, 1:, :, :], label[:, 1:, :, :])
        return cam, cam_rv, loss_cls, loss_rvmin, loss_p

Will look into why when I have time.
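
One difference I can already point at: in the second version cam_little is detached, so label1, and with it loss_cls, no longer backpropagates into fc8, whereas in the first version it does. Whether that is the whole story still needs checking. A minimal standalone sketch of the clone() vs. clone().detach() behavior (toy tensors, names are mine):

    import torch

    x = torch.randn(2, 3, requires_grad=True)

    a = x.clone()            # clone() stays attached to the autograd graph
    a.sum().backward()
    print(x.grad)            # all ones: gradients flow back through the clone

    x.grad = None
    b = x.clone().detach()   # detach() cuts the copy off from the graph
    print(b.requires_grad)   # False: nothing computed from b reaches x in backward

So in the "OK" version the classification loss is effectively a constant with respect to the network parameters, which might by itself change the collapse behavior.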
