图像融合常见损失函数总结

重建损失(MSE和MAE)

  • 均方误差(Mean Squared Error, MSE):

 其中,I_{i} 是源图像的像素值,\widehat{I_{i}}是融合图像的像素值,N 是图像的像素总数

作用:用于衡量源图像和融合图像之间的平均平方差异,强调较大的误差。目标是使融合图像尽可能接近源图像。常见用于图像的强度损失函数

平均绝对误差(Mean Absolute Error, MAE): 

在CDDfuse改进为 

L=\frac{1}{HW}\left \| I_{f}-max\left ( I_{vi},I_{ir} \right ) \right \|_{1}

 作用:用于衡量源图像和融合图像之间的平均绝对差异,对异常值不敏感,能平滑地反映误差。

 结构相似性损失(Structural Similarity Loss, SSIM):

 其中,𝑆𝑆𝐼𝑀(𝐼,𝐼^) 是源图像 I和融合图像\widehat{I}之间的结构相似性指数。

作用:用于衡量两幅图像之间的结构相似性,比MSE和MAE更能反映图像的感知质量。它考虑了亮度、对比度和结构等因素,对人眼感知更友好。 

全变分损失(Total Variation Loss, TV Loss): 

其中,\widehat{I_{i,j}} 是融合图像在位置 (𝑖,𝑗)的像素值。

作用:用于减少融合图像中的噪声和伪影,增强图像的平滑性。通过惩罚图像梯度的大变化,能够有效地减少图像中的不连续性和噪声。

 边缘保持损失(Edge Preservation Loss)

 其中,\triangledown I_{i,j} 和 \triangledown \widehat{I_{i,j}}分别是源图像和融合图像在位置 (𝑖,𝑗) 处的梯度。

作用:用于保持源图像中的边缘信息,确保融合图像中的边缘与源图像相似。这有助于在融合过程中保留重要的细节和轮廓信息。 

平滑损失函数:

𝑝是 I_{f}\phi _{f}中的像素位置索引,𝑅 表示 𝑝 的邻域集合,𝛼 是系数 

作用:利用双边滤波器,以融合图像避免过度平滑

代码实现:未验证仅供参考,过段时间验证,哈哈

import torch
import torch.nn.functional as F


#在这个实现中:
#bilateral_filter 函数计算图像中每个像素与其邻域像素之间的双边滤波器权重。
#smoothness_loss 函数计算基于双边滤波器的平滑损失,通过遍历每个像素及其邻域像素的变化来实现。
def bilateral_filter(image, neighborhood_radius=1):
    # image: Tensor of shape (B, C, H, W)
    B, C, H, W = image.size()
    smooth_loss = 0.0

    for i in range(-neighborhood_radius, neighborhood_radius + 1):
        for j in range(-neighborhood_radius, neighborhood_radius + 1):
            if i == 0 and j == 0:
                continue
            
            shifted_image = torch.roll(image, shifts=(i, j), dims=(2, 3))
            intensity_diff = torch.abs(image - shifted_image)
            spatial_weight = torch.exp(-intensity_diff * alpha)
            smooth_loss += torch.sum(spatial_weight * torch.abs(image - shifted_image))

    return smooth_loss / (B * C * H * W)

def smoothness_loss(deformation_field, fused_image, alpha=0.5):
    # deformation_field: Tensor of shape (B, 2, H, W) representing (dx, dy)
    # fused_image: Tensor of shape (B, C, H, W)
    B, C, H, W = fused_image.size()
    
    # Compute the gradients of the deformation field
    dx = deformation_field[:, 0, :, :]
    dy = deformation_field[:, 1, :, :]
    
    # Apply bilateral filter to the fused image
    bf = bilateral_filter(fused_image, neighborhood_radius=1)
    
    # Compute the smoothness loss
    smooth_loss = 0.0
    for p in range(H):
        for q in range(W):
            neighbors = [(p + dp, q + dq) for dp in range(-1, 2) for dq in range(-1, 2) if (dp, dq) != (0, 0)]
            for (pn, qn) in neighbors:
                if 0 <= pn < H and 0 <= qn < W:
                    smooth_loss += bf[p, q] * torch.abs(dx[p, q] - dx[pn, qn])
                    smooth_loss += bf[p, q] * torch.abs(dy[p, q] - dy[pn, qn])

    return smooth_loss / (B * C * H * W)

# Example usage:
# fused_image: Tensor of shape (B, C, H, W)
# deformation_field: Tensor of shape (B, 2, H, W)
alpha = 0.5
loss = smoothness_loss(deformation_field, fused_image, alpha)
print("Smoothness loss:", loss.item())

 

  • 10
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值