全变分(Total Variation,TV)模型

全变分损失函数用于图像去噪和复原,能保留图像边界信息。代码示例展示了如何在PyTorch中实现TVLoss,计算图像的水平和垂直方向变差,用于优化图像平滑度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在图像复原和图像去噪方面使用。

全变分(Total Variation Loss)源自于图像处理中的全变分去噪(Total Variation Denoising),全变分去噪的有点是既能去噪声,又能保留原图像中的边界信息。而其他简单的去噪方法,如果线性平滑或者中值滤波,在去噪的同时会平滑图像中的边界等信息,损害图像所表达的信息。

全变分去噪的基本思想是,如果图像的细节有很多高频信息(如尖刺、噪点等),那么整幅图像的梯度幅值之和(全变分)是比较大的,如果能够使整幅图像的梯度积分之和降低,就达到了去噪的目的。

参考:(29 封私信 / 2 条消息) 如何理解全变分(Total Variation,TV)模型? - 知乎 (zhihu.com)

PyTorch实现

因为 TV Loss 是对整个 batch 的每一幅图像计算的,所以除以batch_size * c * h * w是一个比较适合作为loss的值,创建TV Loss时可传入一个可选的权重参数(惩罚系数)

作者:imxtx
链接:https://www.zhihu.com/question/47162419/answer/2585330101
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

import torch
import torch.nn as nn
import numpy as np


class TVLoss(nn.Module):
    def __init__(self, weight: float=1) -> None:
        """Total Variation Loss

        Args:
            weight (float): weight of TV loss
        """
        super().__init__()
        self.weight = weight
    
    def forward(self, x):
        batch_size, c, h, w = x.size()
        tv_h = torch.abs(x[:,:,1:,:] - x[:,:,:-1,:]).sum()
        tv_w = torch.abs(x[:,:,:,1:] - x[:,:,:,:-1]).sum()
        return self.weight * (tv_h + tv_w) / (batch_size * c * h * w)


def main():
    tv_loss = TVLoss()
    x = torch.rand([1, 3, 3, 3])
    print(f'Input:\n{x}')
    print(f'The TV Loss is {tv_loss(x).item()}')


if __name__ == "__main__":
    main()

直接用一个函数表示全变分:

def total_variation(x):
    """Anisotropic TV."""
    dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
    dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
    return dx + dy

chatgpt对以上代码解释:

这段代码定义了一个函数total_variation,用于计算输入张量x的总变差(Total Variation)。

函数的实现如下:

  1. dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])):这行代码计算输入张量x在水平方向上的变差。它通过计算相邻像素之间的差的绝对值,并求取水平方向上的平均值。

  2. dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])):这行代码计算输入张量x在垂直方向上的变差。它通过计算相邻像素之间的差的绝对值,并求取垂直方向上的平均值。

  3. return dx + dy:这行代码返回水平方向和垂直方向上变差的总和作为总变差的值。

总变差是一种用于衡量图像平滑度的指标,它量化了图像中相邻像素之间的差异程度。较小的总变差值表示图像较为平滑,而较大的总变差值表示图像较为纹理丰富或边缘明显。总变差在图像处理和计算机视觉中广泛应用,用于图像去噪、图像恢复、图像分割等任务中。

理解:

最后两个维度是宽高(b,c,h,w),j该函数就是计算横向和竖向相邻两个像素点的差值,然后取绝对值在求均值,这个作为损失函数可以使得图像更平滑。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值