在图像复原和图像去噪方面使用。
全变分(Total Variation Loss)源自于图像处理中的全变分去噪(Total Variation Denoising),全变分去噪的有点是既能去噪声,又能保留原图像中的边界信息。而其他简单的去噪方法,如果线性平滑或者中值滤波,在去噪的同时会平滑图像中的边界等信息,损害图像所表达的信息。
全变分去噪的基本思想是,如果图像的细节有很多高频信息(如尖刺、噪点等),那么整幅图像的梯度幅值之和(全变分)是比较大的,如果能够使整幅图像的梯度积分之和降低,就达到了去噪的目的。
参考:(29 封私信 / 2 条消息) 如何理解全变分(Total Variation,TV)模型? - 知乎 (zhihu.com)
PyTorch实现
因为 TV Loss 是对整个 batch 的每一幅图像计算的,所以除以batch_size * c * h * w是一个比较适合作为loss的值,创建TV Loss时可传入一个可选的权重参数(惩罚系数)
作者:imxtx
链接:https://www.zhihu.com/question/47162419/answer/2585330101
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
import torch
import torch.nn as nn
import numpy as np
class TVLoss(nn.Module):
def __init__(self, weight: float=1) -> None:
"""Total Variation Loss
Args:
weight (float): weight of TV loss
"""
super().__init__()
self.weight = weight
def forward(self, x):
batch_size, c, h, w = x.size()
tv_h = torch.abs(x[:,:,1:,:] - x[:,:,:-1,:]).sum()
tv_w = torch.abs(x[:,:,:,1:] - x[:,:,:,:-1]).sum()
return self.weight * (tv_h + tv_w) / (batch_size * c * h * w)
def main():
tv_loss = TVLoss()
x = torch.rand([1, 3, 3, 3])
print(f'Input:\n{x}')
print(f'The TV Loss is {tv_loss(x).item()}')
if __name__ == "__main__":
main()
直接用一个函数表示全变分:
def total_variation(x):
"""Anisotropic TV."""
dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
return dx + dy
chatgpt对以上代码解释:
这段代码定义了一个函数total_variation
,用于计算输入张量x
的总变差(Total Variation)。
函数的实现如下:
-
dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
:这行代码计算输入张量x
在水平方向上的变差。它通过计算相邻像素之间的差的绝对值,并求取水平方向上的平均值。 -
dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
:这行代码计算输入张量x
在垂直方向上的变差。它通过计算相邻像素之间的差的绝对值,并求取垂直方向上的平均值。 -
return dx + dy
:这行代码返回水平方向和垂直方向上变差的总和作为总变差的值。
总变差是一种用于衡量图像平滑度的指标,它量化了图像中相邻像素之间的差异程度。较小的总变差值表示图像较为平滑,而较大的总变差值表示图像较为纹理丰富或边缘明显。总变差在图像处理和计算机视觉中广泛应用,用于图像去噪、图像恢复、图像分割等任务中。
理解:
最后两个维度是宽高(b,c,h,w),j该函数就是计算横向和竖向相邻两个像素点的差值,然后取绝对值在求均值,这个作为损失函数可以使得图像更平滑。