计算图像的PSNR和SSIM(仿照compressai)
计算图像的PSNR和SSIM(仿照compressai)
按照 CompressAI 提供的计算方式计算图像的 PSNR 和 MS-SSIM。
可以批量计算图像的 PSNR 和 MS-SSIM,每张图片的计算结果以及所有图片的平均值都会保存到 1.txt 中。
python代码:
# Copyright 2020 by Gongfan Fang, Zhejiang University.
# All rights reserved.
import warnings
import torch
import torch.nn.functional as F
import abc
import io
import os
import platform
import subprocess
import sys
import time
from tempfile import mkstemp
from typing import Dict, List, Optional, Union
import numpy as np
import PIL
import PIL.Image as Image
import torch
from PIL import Image
from tqdm import tqdm
def _fspecial_gauss_1d(size, sigma):
    r"""Create a normalized 1-D Gaussian kernel.

    Args:
        size (int): number of taps in the kernel.
        sigma (float): standard deviation of the Gaussian.

    Returns:
        torch.Tensor: float kernel of shape (1, 1, size) whose entries sum to 1.
    """
    # Tap positions centred on zero, e.g. size=5 -> [-2, -1, 0, 1, 2].
    coords = torch.arange(size, dtype=torch.float) - size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    return g.unsqueeze(0).unsqueeze(0)
def gaussian_filter(input, win):
    r"""Blur ``input`` with a 1-D kernel applied along each spatial axis in turn.

    Args:
        input (torch.Tensor): batch of tensors to blur, 4-D (N, C, H, W) or
            5-D (N, C, D, H, W).
        win (torch.Tensor): 1-D kernel stored in the last dimension; every
            dimension between the first and the last must have size 1. Its
            first dimension is expected to match C, since the convolution is
            grouped with one group per channel.

    Returns:
        torch.Tensor: blurred tensor. No padding is applied, so each filtered
        spatial axis shrinks by ``win.shape[-1] - 1``. Axes shorter than the
        kernel are left unfiltered and a warning is emitted.

    Raises:
        AssertionError: if ``win`` has a non-singleton middle dimension.
        NotImplementedError: if ``input`` is neither 4-D nor 5-D.
    """
    assert all(ws == 1 for ws in win.shape[1:-1]), win.shape

    ndim = len(input.shape)
    if ndim == 4:
        conv = F.conv2d
    elif ndim == 5:
        conv = F.conv3d
    else:
        raise NotImplementedError(input.shape)

    C = input.shape[1]
    out = input
    for i, s in enumerate(input.shape[2:]):
        if s >= win.shape[-1]:
            # Swap the kernel taps onto spatial axis ``2 + i`` before convolving.
            out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
        else:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
            )
    return out
def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)):
    r"""Compute the per-channel SSIM and contrast-structure (cs) terms of X and Y.

    Args:
        X (torch.Tensor): images.
        Y (torch.Tensor): images.
        data_range (float or int): value range of input images (usually 1.0 or 255).
        win (torch.Tensor): 1-D gauss kernel; moved to the device and dtype of X.
        size_average (bool, optional): accepted for interface compatibility
            only; this function does not read it.
        K (list or tuple, optional): scalar constants (K1, K2).

    Returns:
        tuple: ``(ssim_per_channel, cs)``. Each is the corresponding map
        flattened from dimension 2 onwards and averaged over that flattened
        axis.
    """
    K1, K2 = K
    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    # Local means.
    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances: E[ab] - E[a]E[b].
    sigma1_sq = gaussian_filter(X * X, win) - mu1_sq
    sigma2_sq = gaussian_filter(Y * Y, win) - mu2_sq
    sigma12 = gaussian_filter(X * Y, win) - mu1_mu2

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)  # set alpha=beta=gamma=1
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
    cs = torch.flatten(cs_map, 2).mean(-1)
    return ssim_per_channel, cs
def ssim(
    X,
    Y,
    data_range=255,
    size_average=True,
    win_size=11,
    win_sigma=1.5,
    win=None,
    K=(0.01, 0.03),
    nonnegative_ssim=False,
):
    r"""Interface of ssim.

    Args:
        X (torch.Tensor): a batch of images, 4-D (N,C,H,W) or 5-D.
        Y (torch.Tensor): a batch of images with the same shape and type as X.
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if True, the ssim of all images is
            averaged into a scalar; otherwise one value per image is returned.
        win_size (int, optional): the size of the gauss kernel; ignored when
            ``win`` is given.
        win_sigma (float, optional): sigma of the normal distribution.
        win (torch.Tensor, optional): 1-D gauss kernel. If None, a new kernel
            is created from ``win_size`` and ``win_sigma``.
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger
            K2 constant (e.g. 0.4) if you get negative or NaN results.
        nonnegative_ssim (bool, optional): force the ssim response to be
            nonnegative with relu.

    Returns:
        torch.Tensor: ssim results.

    Raises:
        ValueError: if the shapes or types of X and Y differ, if the inputs
            are not 4-D or 5-D after squeezing, or if the window size is even.
    """
    if X.shape != Y.shape:
        raise ValueError("Input images should have the same dimensions.")

    # Squeeze singleton dimensions from the last one down to index 2, so the
    # batch and channel dimensions are never touched.
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if len(X.shape) not in (4, 5):
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if X.type() != Y.type():
        raise ValueError("Input images should have the same dtype.")

    if win is not None:  # an explicit kernel determines win_size
        win_size = win.shape[-1]

    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        # One copy of the kernel per channel.
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
    if nonnegative_ssim:
        ssim_per_channel = torch.relu(ssim_per_channel)

    if size_average:
        return ssim_per_channel.mean()
    return ssim_per_channel.mean(1)
def ms_ssim(
    X, Y, data_range=255, size_average=True, win_size=11, win_sigma=1.5, win=None, weights=None, K=(0.01, 0.03)
):
    r"""Interface of ms-ssim.

    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W).
        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W).
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if True, the ms-ssim of all images is
            averaged into a scalar; otherwise one value per image is returned.
        win_size (int, optional): the size of the gauss kernel; ignored when
            ``win`` is given.
        win_sigma (float, optional): sigma of the normal distribution.
        win (torch.Tensor, optional): 1-D gauss kernel. If None, a new kernel
            is created from ``win_size`` and ``win_sigma``.
        weights (list, optional): weights for the different levels; defaults
            to five levels.
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger
            K2 constant (e.g. 0.4) if you get negative or NaN results.

    Returns:
        torch.Tensor: ms-ssim results.

    Raises:
        ValueError: if the shapes or types of X and Y differ, if the inputs
            are not 4-D or 5-D after squeezing, or if the window size is even.
        AssertionError: if the smaller of the last two dimensions is not
            larger than ``(win_size - 1) * 2 ** 4``.
    """
    if X.shape != Y.shape:
        raise ValueError("Input images should have the same dimensions.")

    # Squeeze singleton dimensions from the last one down to index 2, so the
    # batch and channel dimensions are never touched.
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if X.type() != Y.type():
        raise ValueError("Input images should have the same dtype.")

    if len(X.shape) == 4:
        avg_pool = F.avg_pool2d
    elif len(X.shape) == 5:
        avg_pool = F.avg_pool3d
    else:
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if win is not None:  # an explicit kernel determines win_size
        win_size = win.shape[-1]

    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")

    smaller_side = min(X.shape[-2:])
    assert smaller_side > (win_size - 1) * (
        2 ** 4
    ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4))

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    # as_tensor builds the weights directly in X's dtype/device and also
    # accepts weights that are already a tensor.
    weights = torch.as_tensor(weights, dtype=X.dtype, device=X.device)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        # One copy of the kernel per channel.
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    levels = weights.shape[0]
    mcs = []
    for i in range(levels):
        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)

        if i < levels - 1:
            mcs.append(torch.relu(cs))
            # Pad odd-sized axes by one so the 2x average pooling covers them.
            padding = [s % 2 for s in X.shape[2:]]
            X = avg_pool(X, kernel_size=2, padding=padding)
            Y = avg_pool(Y, kernel_size=2, padding=padding)

    ssim_per_channel = torch.relu(ssim_per_channel)  # (batch, channel)
    mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0)  # (level, batch, channel)
    ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0)

    if size_average:
        return ms_ssim_val.mean()
    return ms_ssim_val.mean(1)
class SSIM(torch.nn.Module):
    """Module wrapper around :func:`ssim` with a Gaussian kernel built once at construction."""

    def __init__(
        self,
        data_range=255,
        size_average=True,
        win_size=11,
        win_sigma=1.5,
        channel=3,
        spatial_dims=2,
        K=(0.01, 0.03),
        nonnegative_ssim=False,
    ):
        r"""Class for ssim.

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            win_size (int, optional): the size of gauss kernel
            win_sigma (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            spatial_dims (int, optional): number of singleton dimensions
                placed between the channel dimension and the kernel taps; the
                default 2 yields a kernel of shape (channel, 1, 1, win_size).
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
            nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
        """
        super().__init__()
        self.win_size = win_size
        # Plain attribute (not a registered buffer): the kernel is moved to
        # the input's device and dtype on every call inside the ssim code.
        self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.K = K
        self.nonnegative_ssim = nonnegative_ssim

    def forward(self, X, Y):
        """Return ``ssim(X, Y)`` evaluated with the settings stored on this module."""
        return ssim(
            X,
            Y,
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            K=self.K,
            nonnegative_ssim=self.nonnegative_ssim,
        )
class MS_SSIM(torch.nn.Module):
    """Module wrapper around :func:`ms_ssim` with a Gaussian kernel built once at construction."""

    def __init__(
        self,
        data_range=255,
        size_average=True,
        win_size=11,
        win_sigma=1.5,
        channel=3,
        spatial_dims=2,
        weights=None,
        K=(0.01, 0.03),
    ):
        r"""Class for ms-ssim.

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
            win_size (int, optional): the size of gauss kernel
            win_sigma (float, optional): sigma of normal distribution
            channel (int, optional): input channels (default: 3)
            spatial_dims (int, optional): number of singleton dimensions
                placed between the channel dimension and the kernel taps; the
                default 2 yields a kernel of shape (channel, 1, 1, win_size).
            weights (list, optional): weights for different levels
            K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
        """
        super().__init__()
        self.win_size = win_size
        # Plain attribute (not a registered buffer): the kernel is moved to
        # the input's device and dtype on every call inside the ssim code.
        self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights
        self.K = K

    def forward(self, X, Y):
        """Return ``ms_ssim(X, Y)`` evaluated with the settings stored on this module."""
        return ms_ssim(
            X,
            Y,
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            weights=self.weights,
            K=self.K,
        )
def read_image(filepath: str, mode: str = "RGB") -> "Image.Image":
    """Return the image at ``filepath`` as a PIL image converted to ``mode``.

    Args:
        filepath: path of the image file to open.
        mode: PIL mode passed to ``convert`` (default ``"RGB"``).

    Returns:
        The converted PIL image.

    Raises:
        ValueError: if ``filepath`` is not an existing regular file.
    """
    if not os.path.isfile(filepath):
        raise ValueError(f'Invalid file "{filepath}".')
    # convert() yields a new image object, so the handle opened on the source
    # file can be released as soon as the conversion is done.
    with Image.open(filepath) as img:
        return img.convert(mode)
def _compute_ms_ssim(a, b, max_val: float = 255.0) -> float:
    """Return the MS-SSIM of tensors ``a`` and ``b`` as a Python number.

    ``max_val`` is forwarded to ``ms_ssim`` as ``data_range``; the result is
    the scalar produced with ``ms_ssim``'s default ``size_average=True``.
    """
    return ms_ssim(a, b, data_range=max_val).item()
def _compute_psnr(a, b, max_val: float = 255.0) -> float:
    """Return the PSNR in dB between tensors ``a`` and ``b``.

    Args:
        a: first tensor.
        b: second tensor, broadcast-compatible with ``a``.
        max_val: peak signal value used in the PSNR formula.

    Returns:
        ``20 * log10(max_val) - 10 * log10(mse)``, or ``inf`` when the mean
        squared error is exactly zero (identical inputs).
    """
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        # Avoid log10(0), which would emit a RuntimeWarning before yielding inf.
        return float("inf")
    return float(20 * np.log10(max_val) - 10 * np.log10(mse))
# Dispatch table: metric name -> function called as f(a, b, max_val).
_metric_functions = {"psnr": _compute_psnr, "ms-ssim": _compute_ms_ssim}
def compute_metrics1(
    a: Union[np.ndarray, Image.Image],
    b: Union[np.ndarray, Image.Image],
    max_val: float = 255.0,
) -> float:
    """Return the MS-SSIM between images `a` and `b`.

    Args:
        a: first image, as a numpy array or a PIL image.
        b: second image, as a numpy array or a PIL image.
        max_val: value range of the images, passed on as the data range.

    Returns:
        The MS-SSIM value as a single number (not a dict).
    """

    def _convert(x):
        """Turn an image into a float tensor with a leading batch dimension."""
        if isinstance(x, Image.Image):
            x = np.asarray(x)
        x = torch.from_numpy(x.copy()).float().unsqueeze(0)
        if x.dim() == 3:
            # 2-D single-channel input (H, W): (1, H, W) -> (1, 1, H, W)
            x = x.unsqueeze(1)
        elif x.size(3) == 3:
            # (1, H, W, 3) -> (1, 3, H, W)
            x = x.permute(0, 3, 1, 2)
        return x

    a = _convert(a)
    b = _convert(b)
    return _compute_ms_ssim(a, b, max_val)
def compute_metrics(
    a: Union[np.ndarray, Image.Image],
    b: Union[np.ndarray, Image.Image],
    metrics: Optional[List[str]] = None,
    max_val: float = 255.0,
) -> Dict[str, float]:
    """Return the requested quality metrics between images `a` and `b`.

    Args:
        a: first image, as a numpy array or a PIL image.
        b: second image, as a numpy array or a PIL image.
        metrics: names of the metrics to compute, looked up in
            ``_metric_functions``; defaults to ``["psnr"]``.
        max_val: value range of the images, passed to every metric function.

    Returns:
        A dict mapping each requested metric name to its value.

    Raises:
        KeyError: if a requested name is not in ``_metric_functions``.
    """
    if metrics is None:
        metrics = ["psnr"]

    def _convert(x):
        """Turn an image into a float tensor with a leading batch dimension."""
        if isinstance(x, Image.Image):
            x = np.asarray(x)
        x = torch.from_numpy(x.copy()).float().unsqueeze(0)
        if x.dim() == 3:
            # 2-D single-channel input (H, W): (1, H, W) -> (1, 1, H, W)
            x = x.unsqueeze(1)
        elif x.size(3) == 3:
            # (1, H, W, 3) -> (1, 3, H, W)
            x = x.permute(0, 3, 1, 2)
        return x

    a = _convert(a)
    b = _convert(b)
    return {metric_name: _metric_functions[metric_name](a, b, max_val) for metric_name in metrics}
def _load_img(img):
    """Load the image at path ``img`` through ``read_image``, using its absolute path."""
    return read_image(os.path.abspath(img))
# Placeholder directories: replace with the folder of original images and the
# folder of reconstructed images. Files are paired by identical file name.
in_filepath = '原始图像路径'
out_filepath = '重构图像路径'

# Only the PNG files of the reconstruction folder are scored; sorting makes
# the order of the report reproducible.
filelist = sorted(name for name in os.listdir(out_filepath) if name.endswith(".png"))

list_psnr = []
list_ssim = []
# The context manager guarantees the report is flushed and closed even if a
# metric computation fails part-way through.
with open(r'1.txt', mode='w', encoding='utf-8') as file:
    for im in tqdm(filelist):
        # os.path.join inserts the separator, so the directories work with or
        # without a trailing slash.
        metrics = compute_metrics(
            _load_img(os.path.join(out_filepath, im)),
            _load_img(os.path.join(in_filepath, im)),
            ["psnr", "ms-ssim"],
        )
        file.write("{}:{}".format(im, metrics) + '\n')
        list_psnr.append(metrics['psnr'])
        list_ssim.append(metrics['ms-ssim'])

    mean_psnr_line = "mean_psnr:{}".format(np.mean(list_psnr))
    mean_ssim_line = "mean_ssim:{}".format(np.mean(list_ssim))
    print(mean_psnr_line)
    print(mean_ssim_line)
    file.write(mean_psnr_line + '\n')
    file.write(mean_ssim_line + '\n')