def edge_aware_loss_v2(rgb, disp, skymask=None):
    """Compute an edge-aware smoothness loss for a disparity map.

    Disparity gradients are penalised, but each one is down-weighted by
    ``exp(-|image gradient|)``. This allows disparity discontinuities where
    the colour image itself has edges.

    Tensors are indexed channels-last. Dims 1 and 2 are the spatial axes
    (dim 2 is treated as "x", dim 1 as "y"), and dim 3 is the channel axis
    that the RGB gradient is averaged over. This is presumably
    (batch, height, width, channels) — confirm against callers.

    Args:
        rgb: 4-D colour image, channels on dim 3.
        disp: 4-D disparity map with the same spatial size as ``rgb``.
        skymask: optional 4-D weight mask with the same spatial size as
            ``disp``. The penalty at each location becomes
            ``g + skymask * g``, i.e. it is scaled by ``1 + skymask``
            (doubled for a binary mask).

    Returns:
        Scalar tensor: mean weighted gradient along dim 2 plus mean weighted
        gradient along dim 1.
    """
    # Normalise by the spatial mean (over dims 1 and 2) so the loss does
    # not shrink simply because the overall disparity scale shrinks.
    mean_disp = disp.mean(dim=(1, 2), keepdim=True)
    disp = disp / (mean_disp + 1e-7)

    # Absolute finite differences of the normalised disparity.
    grad_disp_x = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
    grad_disp_y = torch.abs(disp[:, :-1, :, :] - disp[:, 1:, :, :])

    # Image gradients, averaged over the channel dim (dim 3).
    grad_rgb_x = torch.mean(
        torch.abs(rgb[:, :, :-1, :] - rgb[:, :, 1:, :]), 3, keepdim=True
    )
    grad_rgb_y = torch.mean(
        torch.abs(rgb[:, :-1, :, :] - rgb[:, 1:, :, :]), 3, keepdim=True
    )

    # Down-weight disparity gradients where the image has strong edges.
    # Out-of-place ops are used so autograd stays valid even when an
    # input (e.g. skymask) requires grad; the in-place form raises on
    # backward in that case.
    grad_disp_x = grad_disp_x * torch.exp(-grad_rgb_x)
    grad_disp_y = grad_disp_y * torch.exp(-grad_rgb_y)

    if skymask is not None:
        # Extra weight inside the mask: g + mask * g.
        # The mask value of the first pixel of each neighbouring pair is used.
        grad_disp_x = grad_disp_x + skymask[:, :, :-1, :] * grad_disp_x
        grad_disp_y = grad_disp_y + skymask[:, :-1, :, :] * grad_disp_y

    return grad_disp_x.mean() + grad_disp_y.mean()
# Reference: edge-aware smoothness loss (参考链接: 边缘感知平滑损失; no URL was included in the original text)