PSNR损失函数
class PSNRLoss(nn.Module):
def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
super(PSNRLoss, self).__init__()
assert reduction == 'mean'
self.loss_weight = loss_weight
self.scale = 10 / np.log(10)
self.toY = toY
self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
self.first = True
def forward(self, pred, target):
assert len(pred.size()) == 4
if self.toY:
if self.first:
self.coef = self.coef.to(pred.device)
self.first = False
pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
pred, target = pred / 255., target / 255.
pass
assert len(pred.size()) == 4
return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
#等价于self.loss_weight *10 * torch.log10(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
self.loss_weight是固定值1
其中self.scale=10 / np.log(10)
10 / np.log(10) 表示的是一个常数,它的值约为 4.34294481903。这个常数通常用于计算信噪比(SNR)以及峰值信噪比(PSNR)
在计算 PSNR 时,使用 10 / np.log(10) 的值是因为 PSNR 的定义中包含对峰值信号功率和均方误差之间的对数关系。通过取对数并乘以一个常数(10 / np.log(10)),可以将均方误差转换为以对数刻度表示的度量,从而得到更符合人类感知的图像质量评估指标。
因此,10 / np.log(10) 在 PSNR 计算中扮演着重要的角色,它将均方误差转换为了以分贝(dB)为单位的信噪比度量。
((pred - target) ** 2).mean(dim=(1, 2, 3))
其中
((pred - target) ** 2).mean(dim=(1, 2, 3))等价于MSE
表示对预测图像和目标图像之间每个像素点的差的平方进行求和,并且取平均值。这一步计算得到的结果是一个张量,其维度可能会降低,具体取决于输入张量的维度。
在这里,.mean(dim=(1, 2, 3))
的含义是对第 1、2、3 维(通常是通道、高和宽)进行求平均值。这意味着对每个样本,对每个通道、高和宽上的差的平方求和并取平均值。最终得到的结果是每个样本的均方误差(MSE),形成一个张量。
为了避免出现除零错误,代码中添加了一个小的偏置项 1e-8
,即 + 1e-8
。这样可以确保即使计算出的均方误差接近0时,也不会出现除零错误或者负无穷值。
等价函数
self.loss_weight *10 torch.log10(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
import torch
import torch.nn as nn
class PSNRLoss(nn.Module):
def __init__(self):
super(PSNRLoss, self).__init__()
self.mse = nn.MSELoss()
def forward(self, img1, img2):
mse = self.mse(img1, img2)
psnr = 10 * torch.log10(1 / mse)
return psnr