I suspected that one loss term (loss_p) was causing the training collapse (the CAM output goes to all zeros), so I tried removing it from the total loss. It turned out that merely computing loss_p triggers the collapse, even when it does not participate in backward(); if it is never computed, there is no collapse. Changing the first code block below into the second fixed it. Pay particular attention to cam_little, including where it feeds into the loss_p computation.
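For quick reference, the two versions differ only in the three lines below; everything else in forward is identical:

    # not OK:
    cam_little = self.fc8(feature)
    cam = cam_little.clone()
    ...
    loss_p = self.pNet(feature, cam_little[:, 1:, ...].detach(), label[:, 1:, ...])

    # OK:
    cam = self.fc8(feature)
    cam_little = cam.clone().detach()
    ...
    loss_p = self.pNet(feature, cam_little[:, 1:, ...], label[:, 1:, ...])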
Not OK (collapses):
def forward(self, x, label, need_inter=False, need_p=True):
    # print("1=", meminfo.free / 1024**3)  # free GPU memory (GiB)
    scale_factor = 0.3
    N, C, H, W = x.size()
    d = super().forward_as_dict(x)
    feature = self.dropout7(d['conv6'])
    cam_little = self.fc8(feature)  # fc8 output is bound to cam_little...
    cam = cam_little.clone()        # ...and cam is its clone
    n, c, h, w = cam.size()
    with torch.no_grad():
        cam_d = F.relu(cam.detach())
        cam_d_max = torch.max(cam_d.view(n, c, -1), dim=-1)[0].view(n, c, 1, 1) + 1e-5
        cam_d_norm = F.relu(cam_d - 1e-5) / cam_d_max
        cam_d_norm[:, 0, :, :] = 1 - torch.max(cam_d_norm[:, 1:, :, :], dim=1)[0]
        cam_max = torch.max(cam_d_norm[:, 1:, :, :], dim=1, keepdim=True)[0]
        cam_d_norm[:, 1:, :, :][cam_d_norm[:, 1:, :, :] < cam_max] = 0
    f8_3 = F.relu(self.f8_3(d['conv4'].detach()), inplace=True)
    f8_4 = F.relu(self.f8_4(d['conv5'].detach()), inplace=True)
    # recompute_scale_factor is only meaningful with scale_factor, not with an
    # explicit size, so it is dropped from the size= calls below
    x_s = F.interpolate(input=x, size=(h, w), mode='bilinear', align_corners=True)
    f = torch.cat([x_s, f8_3, f8_4], dim=1)
    n, c, h, w = f.size()
    cam = F.interpolate(input=cam, size=(H, W), mode='bilinear', align_corners=True)
    cam_rv = F.interpolate(input=self.PCM(cam_d_norm, f), size=(H, W), mode='bilinear', align_corners=True)
    # cam_rv = cam
    if need_p:
        # .detach() returns a view that still shares storage with cam_little
        loss_p = self.pNet(feature, cam_little[:, 1:, ...].detach(), label[:, 1:, ...])
    else:
        loss_p = torch.tensor(0.0, dtype=torch.float64, device=x.device)
    label1 = F.adaptive_avg_pool2d(cam_little, (1, 1))  # pooled from the live fc8 output (requires grad)
    loss_rvmin = adaptive_min_pooling_loss((cam_rv * label)[:, 1:, :, :])
    if need_inter:
        cam = F.interpolate(input=visualization.max_norm(cam), scale_factor=scale_factor, mode='bilinear', align_corners=True, recompute_scale_factor=True) * label
        cam_rv = F.interpolate(input=visualization.max_norm(cam_rv), scale_factor=scale_factor, mode='bilinear', align_corners=True, recompute_scale_factor=True) * label
    else:
        cam = visualization.max_norm(cam) * label
        cam_rv = visualization.max_norm(cam_rv) * label
    loss_cls = F.multilabel_soft_margin_loss(label1[:, 1:, :, :], label[:, 1:, :, :])
    return cam, cam_rv, loss_cls, loss_rvmin, loss_p
OK (no collapse):
def forward(self, x, label, need_inter=False, need_p=True):
    # print("1=", meminfo.free / 1024**3)  # free GPU memory (GiB)
    scale_factor = 0.3
    N, C, H, W = x.size()
    d = super().forward_as_dict(x)
    feature = self.dropout7(d['conv6'])
    cam = self.fc8(feature)            # fc8 output is bound to cam...
    cam_little = cam.clone().detach()  # ...and cam_little is an independent, detached copy
    n, c, h, w = cam.size()
    with torch.no_grad():
        cam_d = F.relu(cam.detach())
        cam_d_max = torch.max(cam_d.view(n, c, -1), dim=-1)[0].view(n, c, 1, 1) + 1e-5
        cam_d_norm = F.relu(cam_d - 1e-5) / cam_d_max
        cam_d_norm[:, 0, :, :] = 1 - torch.max(cam_d_norm[:, 1:, :, :], dim=1)[0]
        cam_max = torch.max(cam_d_norm[:, 1:, :, :], dim=1, keepdim=True)[0]
        cam_d_norm[:, 1:, :, :][cam_d_norm[:, 1:, :, :] < cam_max] = 0
    f8_3 = F.relu(self.f8_3(d['conv4'].detach()), inplace=True)
    f8_4 = F.relu(self.f8_4(d['conv5'].detach()), inplace=True)
    # recompute_scale_factor is only meaningful with scale_factor, not with an
    # explicit size, so it is dropped from the size= calls below
    x_s = F.interpolate(input=x, size=(h, w), mode='bilinear', align_corners=True)
    f = torch.cat([x_s, f8_3, f8_4], dim=1)
    n, c, h, w = f.size()
    cam = F.interpolate(input=cam, size=(H, W), mode='bilinear', align_corners=True)
    cam_rv = F.interpolate(input=self.PCM(cam_d_norm, f), size=(H, W), mode='bilinear', align_corners=True)
    # cam_rv = cam
    if need_p:
        # cam_little is already detached and no longer shares storage with cam
        loss_p = self.pNet(feature, cam_little[:, 1:, ...], label[:, 1:, ...])
    else:
        loss_p = torch.tensor(0.0, dtype=torch.float64, device=x.device)
    label1 = F.adaptive_avg_pool2d(cam_little, (1, 1))  # pooled from the detached copy (no grad into fc8)
    loss_rvmin = adaptive_min_pooling_loss((cam_rv * label)[:, 1:, :, :])
    if need_inter:
        cam = F.interpolate(input=visualization.max_norm(cam), scale_factor=scale_factor, mode='bilinear', align_corners=True, recompute_scale_factor=True) * label
        cam_rv = F.interpolate(input=visualization.max_norm(cam_rv), scale_factor=scale_factor, mode='bilinear', align_corners=True, recompute_scale_factor=True) * label
    else:
        cam = visualization.max_norm(cam) * label
        cam_rv = visualization.max_norm(cam_rv) * label
    loss_cls = F.multilabel_soft_margin_loss(label1[:, 1:, :, :], label[:, 1:, :, :])
    return cam, cam_rv, loss_cls, loss_rvmin, loss_p
When I have time I should look into why.
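One plausible mechanism, assuming self.pNet mutates its second argument in place (an assumption, not verified against pNet's code): .detach() returns a view that shares storage with the original tensor, whereas .clone().detach() allocates an independent copy. In the not-OK version, an in-place write inside pNet would silently overwrite the live fc8 output that label1 and loss_cls are computed from, with no error raised. A minimal sketch of the effect:

    import torch

    w = torch.randn(3, 3, requires_grad=True)
    out = w.sum(dim=0)           # non-leaf tensor, standing in for the fc8 output
    alias = out.detach()         # shares storage with out
    copy = out.clone().detach()  # independent storage

    alias.zero_()                # in-place write through the detached view...
    print(out)                   # ...silently zeroes out as well; no error is raised
    print(copy)                  # the clone().detach() copy is unaffected

    out.sum().backward()         # backward still runs, but the forward values were corrupted

A second difference worth keeping in mind when investigating: in the OK version label1 is pooled from the detached cam_little, so loss_cls no longer back-propagates into fc8 at all; the disappearance of the collapse could also come from that changed gradient path rather than (or in addition to) the storage aliasing.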