TV Loss详解

TV Loss介绍

 TV loss全称Total Variation Loss,其作用主要是降噪,图像中相邻像素值的差异可以通过降低TV Loss来一定程度上进行解决 ,从而保持图像的光滑性。

TV Loss定义

 连续TV Loss的定义为:
J T 0 ( u ) = ∫ D u u x 2 + u y 2 d x d y J_{T_0}(u)=\int_{D_u}\sqrt{u^2_x+u^2_y}dxdy JT0(u)=Duux2+uy2 dxdy其中 u x = ∂ u ∂ x u_x=\frac{\partial u}{\partial x} ux=xu u y = ∂ u ∂ y u_y=\frac{\partial u}{\partial y} uy=yu D u D_u Du是定义域。

 带阶数的TV Loss的定义为:
J T 0 ( u ) = ∫ D u ( u x 2 + u y 2 ) β 2 d x d y J_{T_0}(u)=\int_{D_u}(u^2_x+u^2_y)^{\frac{\beta}{2}}dxdy JT0(u)=Du(ux2+uy2)2βdxdy

 离散TV Loss的定义为:
J T 0 ( u ) = ∑ i , j ( ( x i , j − 1 − x i , j ) 2 + ( x i + 1 , j − x i , j ) 2 ) β 2 J_{T_0}(u)=\sum\limits_{i,j}((x_{i,j-1}-x_{i,j})^2+(x_{i+1,j}-x_{i,j})^2)^{\frac{\beta}{2}} JT0(u)=i,j((xi,j1xi,j)2+(xi+1,jxi,j)2)2β

TV Loss运行代码

import torch 

def tv_loss(input_t):
	temp1 = torch.cat((input_t[:, :, 1:, :], input_t[:, :, -1, :].unsqueeze(2)), 2)
	temp2 = torch.cat((input_t[:, :, :, 1:], input_t[:, :, :, -1].unsqueeze(3)), 3)
	temp = (input_t - temp1)**2 + (input_t - temp2)**2
	return temp.sum()

if __name__ == '__main__':
	input = torch.rand(4,3,32,32)
	print(tv_loss(input))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道2024

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值