毕业论文代码

PSNR损失函数

class PSNRLoss(nn.Module):

    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10 / np.log(10)
        self.toY = toY
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True

    def forward(self, pred, target):
        assert len(pred.size()) == 4
        if self.toY:
            if self.first:
                self.coef = self.coef.to(pred.device)
                self.first = False

            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.

            pred, target = pred / 255., target / 255.
            pass
        assert len(pred.size()) == 4

        return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
    #等价于self.loss_weight *10 * torch.log10(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()

self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()

self.loss_weight是固定值1

其中self.scale=10 / np.log(10)

10 / np.log(10) 表示的是一个常数,它的值约为 4.34294481903。这个常数通常用于计算信噪比(SNR)以及峰值信噪比(PSNR)

在计算 PSNR 时,使用 10 / np.log(10) 的值是因为 PSNR 的定义中包含对峰值信号功率和均方误差之间的对数关系。通过取对数并乘以一个常数(10 / np.log(10)),可以将均方误差转换为以对数刻度表示的度量,从而得到更符合人类感知的图像质量评估指标。

因此,10 / np.log(10) 在 PSNR 计算中扮演着重要的角色,它将均方误差转换为了以分贝(dB)为单位的信噪比度量。

((pred - target) ** 2).mean(dim=(1, 2, 3))

其中((pred - target) ** 2).mean(dim=(1, 2, 3))等价于MSE

表示对预测图像和目标图像之间每个像素点的差的平方进行求和,并且取平均值。这一步计算得到的结果是一个张量,其维度可能会降低,具体取决于输入张量的维度。

在这里,.mean(dim=(1, 2, 3)) 的含义是对第 1、2、3 维(通常是通道、高和宽)进行求平均值。这意味着对每个样本,对每个通道、高和宽上的差的平方求和并取平均值。最终得到的结果是每个样本的均方误差(MSE),形成一个张量。

为了避免出现除零错误,代码中添加了一个小的偏置项 1e-8,即 + 1e-8。这样可以确保即使计算出的均方误差接近0时,也不会出现除零错误或者负无穷值。

等价函数

self.loss_weight *10 torch.log10(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()

import torch
import torch.nn as nn

class PSNRLoss(nn.Module):
    def __init__(self):
        super(PSNRLoss, self).__init__()
        self.mse = nn.MSELoss()

    def forward(self, img1, img2):
        mse = self.mse(img1, img2)
        psnr = 10 * torch.log10(1 / mse)
        return psnr


 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值