批量计算图片的 SSIM 和 PSNR

PSNR 和 SSIM 的计算函数(后文通过 `from utils_test import to_ssim_skimage, to_psnr` 导入,因此可将这段代码保存为 utils_test.py)

import torch
import torch.nn.functional as F
from math import log10
import cv2
import numpy as np
import torchvision
from skimage.metrics import structural_similarity as ssim
def to_psnr(frame_out, gt, *, data_range=1.0):
    """Compute the per-image PSNR (in dB) for a batch of images.

    Args:
        frame_out: predicted images; the first dimension is the batch
            dimension (assumed ``(N, C, H, W)``).
        gt: ground-truth images with the same shape as ``frame_out``.
        data_range: maximum possible pixel value (1.0 for images scaled to
            ``[0, 1]``, e.g. by ``transforms.ToTensor``).

    Returns:
        A list with one Python float per image.  An image pair with zero
        MSE (identical images) yields ``float('inf')`` instead of raising
        ``ZeroDivisionError``.
    """
    mse = F.mse_loss(frame_out, gt, reduction='none')
    # One MSE value per image: average every element of each batch slice.
    per_image_mse = [chunk.mean().item() for chunk in torch.split(mse, 1, dim=0)]
    peak_sq = data_range ** 2  # PSNR = 10 * log10(MAX^2 / MSE)

    psnr_list = []
    for value in per_image_mse:
        if value == 0:
            psnr_list.append(float('inf'))
        else:
            psnr_list.append(10.0 * log10(peak_sq / value))
    return psnr_list

def to_ssim_skimage(dehaze, gt):
    """Compute the per-image SSIM for a batch of images with scikit-image.

    Args:
        dehaze: predicted images, assumed ``(N, C, H, W)`` with values in
            ``[0, 1]`` (``data_range=1`` is passed to skimage).
        gt: ground-truth images with the same shape as ``dehaze``.

    Returns:
        A list with one float SSIM value per image pair.

    Raises:
        ValueError: if ``dehaze`` and ``gt`` have different shapes.
    """
    if dehaze.shape != gt.shape:
        raise ValueError(
            f"dehaze and gt must have the same shape, got "
            f"{tuple(dehaze.shape)} and {tuple(gt.shape)}"
        )
    # (N, C, H, W) -> (N, H, W, C).  Images are taken by indexing the batch
    # axis rather than calling squeeze(): squeeze() would also drop a single
    # channel (C == 1), and skimage would then treat the width as channels.
    dehaze_np = dehaze.detach().permute(0, 2, 3, 1).cpu().numpy()
    gt_np = gt.detach().permute(0, 2, 3, 1).cpu().numpy()

    ssim_list = []
    for pred_img, gt_img in zip(dehaze_np, gt_np):
        # skimage's multichannel SSIM is the mean of the per-channel SSIMs;
        # computing it explicitly avoids the ``multichannel`` keyword, which
        # was deprecated in scikit-image 0.19 and dropped in later releases.
        per_channel = [
            ssim(pred_img[..., c], gt_img[..., c], data_range=1)
            for c in range(pred_img.shape[-1])
        ]
        ssim_list.append(float(np.mean(per_channel)))
    return ssim_list


def predict(gridnet, test_data_loader, *, device=None, save_debug=True):
    """Run the model over a test loader and return the average PSNR.

    Each batch is ``(frame1, frame2, frame3)``: the model is called as
    ``gridnet(frame1, frame3)`` and its output is compared with ``frame2``.

    Args:
        gridnet: the model to evaluate.  NOTE(review): this function does not
            call ``gridnet.eval()``; the caller should do so if the model uses
            dropout or batch-norm.
        test_data_loader: iterable yielding ``(frame1, frame2, frame3)``
            tensor batches.
        device: device to run on.  Defaults to CUDA when available (the
            original behavior), otherwise CPU.
        save_debug: when True, save ``./image<batch_idx>.png`` per batch,
            stacking frame1, the prediction, the ground truth and frame3.

    Returns:
        The mean per-image PSNR over all batches.

    Raises:
        ValueError: if the loader yields no images.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    psnr_list = []
    with torch.no_grad():
        for batch_idx, (frame1, frame2, frame3) in enumerate(test_data_loader):
            frame1 = frame1.to(device)
            frame3 = frame3.to(device)
            gt = frame2.to(device)

            frame_out = gridnet(frame1, frame3)

            if save_debug:
                frame_debug = torch.cat((frame1, frame_out, gt, frame3), dim=0)
                torchvision.utils.save_image(frame_debug, f"./image{batch_idx}.png")

            # --- Collect per-image PSNR for the final average --- #
            psnr_list.extend(to_psnr(frame_out, gt))

    if not psnr_list:
        raise ValueError("test_data_loader yielded no images; cannot average PSNR")
    return sum(psnr_list) / len(psnr_list)

批量计算有两种方式:一种不使用迭代器,直接遍历文件夹逐张计算;另一种使用 PyTorch 的 DataLoader(迭代器)按批计算。两种方式都按文件列表的顺序把结果图与 GT 图一一配对,因此请确保两个文件夹中的图片一一对应(例如文件名相同)。

1. 不使用迭代器(逐张计算)

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms

from utils_test import to_ssim_skimage, to_psnr
from metrics import ssim,psnr

# Restored images live in <root>/results_8390epoch_17.42psnr and the ground
# truth in <root>/gt.  Images are paired by position after sorting both
# listings, so corresponding files must sort identically (e.g. same names).
root = "D://Random//Involution//BoHuiBor//Image_set//Textpsnr//dense//"
result_dir = os.path.join(root, "results_8390epoch_17.42psnr")
gt_dir = os.path.join(root, "gt")
# os.listdir returns entries in arbitrary order; sort so the pairs line up.
imglist = sorted(os.listdir(result_dir))
GTimglist = sorted(os.listdir(gt_dir))
if len(imglist) != len(GTimglist):
    raise ValueError(
        f"{len(imglist)} result images but {len(GTimglist)} ground-truth images; "
        "the folders must contain matching files"
    )
psnr_list = []
ssim_list = []

for im, GT in zip(imglist, GTimglist):
    print(im, GT)
    # Context managers close the image files once they are converted.
    with Image.open(os.path.join(result_dir, im)) as img, \
            Image.open(os.path.join(gt_dir, GT)) as gt_img:
        pathim_tensor = transforms.ToTensor()(img).unsqueeze(0)
        pathGT_tensor = transforms.ToTensor()(gt_img).unsqueeze(0)
    # Compute each metric once and reuse it for both printing and averaging.
    psnr_value = psnr(pathim_tensor, pathGT_tensor)
    ssim_value = ssim(pathim_tensor, pathGT_tensor).item()
    print(psnr_value, ssim_value)
    psnr_list.append(psnr_value)
    ssim_list.append(ssim_value)
avr_psnr = np.mean(psnr_list)
avr_ssim = np.mean(ssim_list)
print(avr_psnr, avr_ssim)

2. 使用迭代器(DataLoader 按批计算)

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms

from utils_test import to_ssim_skimage, to_psnr

# Dataset root; the evaluation expects "results" and "GT" sub-folders in it.
root = "D://Random//Involution//BoHuiBor//Image_set//epoch_33.71"
# Per-image metric accumulators filled by the evaluation loop.
psnr_list, ssim_list = [], []
class datasets(Dataset):
    """Paired (restored, ground-truth) images for metric evaluation.

    File names are read from ``<root_dir>/results`` and ``<root_dir>/GT`` and
    paired by position after sorting, so matching files must sort
    identically (e.g. share the same names).  Each item is a pair of
    tensors produced by ``CenterCrop(256)`` followed by ``ToTensor``.
    """

    def __init__(self, root_dir=None):
        # Fall back to the module-level ``root`` so existing callers still work.
        self.root = root if root_dir is None else root_dir
        # os.listdir order is arbitrary; sort so X[i] and Y[i] correspond.
        self.X = sorted(os.listdir(os.path.join(self.root, "results")))
        self.Y = sorted(os.listdir(os.path.join(self.root, "GT")))
        if len(self.X) != len(self.Y):
            raise ValueError(
                f"{len(self.X)} result images but {len(self.Y)} GT images; "
                "the folders must contain matching files"
            )

    def __getitem__(self, index):
        a = self.X[index]
        b = self.Y[index]
        print(a, b)
        x_path = os.path.join(self.root, "results", a)
        y_path = os.path.join(self.root, "GT", b)
        crop = transforms.CenterCrop(256)
        to_tensor = transforms.ToTensor()
        # Context managers release the file handles after conversion.
        with Image.open(x_path) as img_x, Image.open(y_path) as img_y:
            pathX_tensor = to_tensor(crop(img_x))
            pathY_tensor = to_tensor(crop(img_y))
        return pathX_tensor, pathY_tensor

    def __len__(self):
        # Use the listing cached in __init__ instead of re-reading the folder.
        return len(self.X)


dataset = datasets()

# Averages do not depend on order, so shuffling is unnecessary; a fixed order
# also keeps the floating-point sums (and printed results) reproducible.
my_dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=False)

for x, y in my_dataloader:
    psnr_list.extend(to_psnr(x, y))
    ssim_list.extend(to_ssim_skimage(x, y))
if not psnr_list:
    raise ValueError("No image pairs found; cannot compute average PSNR/SSIM.")
avr_psnr = sum(psnr_list) / len(psnr_list)
avr_ssim = sum(ssim_list) / len(ssim_list)
print(avr_psnr, avr_ssim)

如果要在服务器(如 Linux)上运行,只需把代码中的 Windows 路径改成服务器上对应的路径即可。

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值