计算PSNR
计算PSNR函数:
def PSNR(pred, gt, height, width, peak=255.0):
    """Return the per-image PSNR (in dB) between ``pred`` and ``gt``.

    Both inputs are flattened to rows of ``height * width`` samples, so a
    single image yields one row and a stack of images yields one row each.

    Args:
        pred: Predicted image(s); anything ``np.asarray`` accepts.
        gt: Ground-truth image(s) with the same number of samples as ``pred``.
        height: Image height in samples.
        width: Image width in samples.
        peak: Maximum possible sample value. Use 255.0 for 8-bit data
            (the default) and 1.0 for data normalized to [0, 1].

    Returns:
        A float array of shape ``(num_images, 1)`` holding the PSNR of each
        image. Identical images do not produce ``inf``: a zero RMSE is
        replaced by 0.01, which caps the result (about 88.13 dB at peak 255).
    """
    # Cast to float before subtracting: uint8 inputs would otherwise wrap
    # around modulo 256 in both the difference and the square.
    pred = np.asarray(pred, dtype=np.float64).reshape(-1, height * width)
    gt = np.asarray(gt, dtype=np.float64).reshape(-1, height * width)
    mse = np.mean((pred - gt) ** 2, axis=1)
    rmse = np.sqrt(mse).reshape(-1, 1)
    # Avoid division by zero for identical images.
    rmse[rmse == 0.0] = 0.01
    return 20 * np.log10(peak / rmse)
完整代码:
先批量读取预测图像pred和真实图像groundtruth的Y分量值,然后使用我们定义的PSNR函数单独计算每张图片Y分量的PSNR,最后打印出平均PSNR。
import argparse, os
import torch
import random
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from vdsr import Net
from dataset import DatasetFromHdf5
import numpy as np
def PSNR(pred, gt, height, width, peak=255.0):
    """Compute the PSNR in dB of each image in ``pred`` against ``gt``.

    Inputs are reshaped to ``(-1, height * width)``, i.e. one row per image,
    and the PSNR is evaluated row by row.

    Args:
        pred: Predicted image(s); anything ``np.asarray`` accepts.
        gt: Ground-truth image(s) with the same number of samples as ``pred``.
        height: Image height in samples.
        width: Image width in samples.
        peak: Largest representable sample value: 255.0 for 8-bit samples
            (default), 1.0 when the data has been normalized to [0, 1].

    Returns:
        Float array of shape ``(num_images, 1)`` with one PSNR value per
        image. A zero RMSE (identical images) is replaced by 0.01 so the
        result stays finite instead of becoming ``inf``.
    """
    flat_shape = (-1, height * width)
    # Work in float64: subtracting and squaring uint8 arrays directly wraps
    # around modulo 256 and yields a wrong error value.
    pred = np.asarray(pred, dtype=np.float64).reshape(flat_shape)
    gt = np.asarray(gt, dtype=np.float64).reshape(flat_shape)
    res = np.sqrt(np.mean((pred - gt) ** 2, axis=1)).reshape(-1, 1)
    # Guard the division below against identical images.
    res[res == 0.0] = 0.01
    psnr = 20 * np.log10(peak / res)
    return psnr
# Average the Y-plane PSNR of every file in rootpath_qp against the file of
# the same name in rootpath_gt.
n = 0
sum_psnr = 0.0
rootpath_qp = 'D:/编码与NN/265_cut/37_wo'
rootpath_gt = 'D:/编码与NN/265_cut/ori'
for i in os.listdir(rootpath_qp):
    n = n + 1
    # File names are parsed as "<num>_<width>x<height>.<ext>".
    size = i.split('.')[0].split('_')[1]
    dims = size.split('x')
    Width_Y = int(dims[0])
    Height_Y = int(dims[1])

    subfile_qp = os.path.join(rootpath_qp, i)
    subfile_gt = os.path.join(rootpath_gt, i)

    # Only the first Height_Y * Width_Y bytes (the Y plane) are used for the
    # metric, so the two quarter-size planes that follow are not read.
    # `with` closes each file even if the read or reshape fails.
    luma_bytes = Height_Y * Width_Y
    with open(subfile_qp, 'rb') as fp1:
        Y1 = np.frombuffer(fp1.read(luma_bytes), np.uint8).reshape((Height_Y, Width_Y))
    with open(subfile_gt, 'rb') as fp2:
        Y2 = np.frombuffer(fp2.read(luma_bytes), np.uint8).reshape((Height_Y, Width_Y))

    sum_psnr = sum_psnr + PSNR(Y1, Y2, Height_Y, Width_Y)
    # NOTE(review): the original indentation was lost; this running total is
    # assumed to be printed once per file - confirm against the author's copy.
    print(sum_psnr)

# Guard the average against an empty input directory.
if n:
    print(sum_psnr / n)
else:
    print('no files found in ' + rootpath_qp)