所有对于算法的改进都是有一个目标的,比如使得图像的质量更好,又或者处理的速度更快,这些都是用来衡量去模糊算法的好坏程度的指标。在图像处理部分,有两个比较常用的衡量图像质量的指标如下:
SSIM (Structural SIMilarity) 结构相似性:一种全参考的图像质量评价指标,它分别从亮度、对比度、结构三*方面度量图像相似性
PSNR (Peak Signal-to-Noise Ratio) 峰值信噪比:一种全参考的图像质量评价指标
计算两个指标的代码如下:
"""这段代码是一个用于测试图像的脚本,包括图像的预处理、模型推断、评价指标计算等功能"""
from __future__ import print_function
import argparse
import numpy as np
import torch
import cv2
import yaml
import os
from torchvision import models, transforms
from torch.autograd import Variable
import shutil
import glob
import tqdm
#from util.metrics import PSNR 源码修改的地方
from albumentations import Compose, CenterCrop, PadIfNeeded
from PIL import Image
#from ssim.ssimlib import SSIM 源码修改的地方
from skimage.metrics import structural_similarity as SSIM #新增的地方,导入相关的库
from skimage.metrics import peak_signal_noise_ratio as PSNR #新增的地方,导入相关的库
from models.networks import get_generator
def get_args(): #定义一个函数get_args,用于解析命令行参数,包括图像文件夹路径和权重文件路径
#1.测试批量图片
parser = argparse.ArgumentParser('Test an image')
parser.add_argument('--img_folder', required=True, help='GoPRO Folder') # default='GOPRO', 图像文件路径
parser.add_argument('--weights_path', required=True, help='Weights path') #default='best_fpn.h5', 权重文件路径
return parser.parse_args()
def prepare_dirs(path): #定义了一个函数prepare_dirs,用于准备目录,如果目录存在则删除并重新创建
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path)
def get_gt_image(path): #定义了一个函数get_gt_image,用于获取原始图像
dir, filename = os.path.split(path)
base, seq = os.path.split(dir)
base, _ = os.path.split(base)
#img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB) #源码修改的地方(注释掉)
img = cv2.cvtColor(cv2.imread(os.path.join(base, _, 'sharp', filename)), cv2.COLOR_BGR2RGB)
return img
def test_image(model, image_path):
"""
测试单张图像,包括图像的预处理、模型推断、评价指标计算等
Args:
model (torch.nn.Module): 训练好的模型
image_path (str): 图像文件路径
Returns:
tuple: PSNR和SSIM评价指标
"""
img_transforms = transforms.Compose([
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
size_transform = Compose([
PadIfNeeded(736, 1280)
])
crop = CenterCrop(720, 1280)
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_s = size_transform(image=img)['image']
img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
img_tensor = img_transforms(img_tensor)
with torch.no_grad():
img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
result_image = model(img_tensor)
result_image = result_image[0].cpu().float().numpy()
result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
result_image = crop(image=result_image)['image']
result_image = result_image.astype('uint8')
gt_image = get_gt_image(image_path)
_, filename = os.path.split(image_path)
psnr = PSNR(result_image, gt_image)
# pilFake = Image.fromarray(result_image) #源码修改的地方(注释掉)
# pilReal = Image.fromarray(gt_image) #源码修改的地方(注释掉)
# ssim = SSIM(pilFake).cw_ssim_value(pilReal) #源码修改的地方(注释掉)
ssim = SSIM(result_image, gt_image, multichannel=True) #新增
return psnr, ssim
def test(model, files):
"""
对一组图像进行测试,计算并输出平均的PSNR和SSIM
Args:
model (torch.nn.Module): 训练好的模型
files (list): 图像文件路径列表
"""
psnr = 0
ssim = 0
for file in tqdm.tqdm(files):
cur_psnr, cur_ssim = test_image(model, file)
psnr += cur_psnr
ssim += cur_ssim
print("PSNR = {}".format(psnr / len(files)))
print("SSIM = {}".format(ssim / len(files)))
if __name__ == '__main__':
args = get_args()
with open('config/config.yaml', encoding='utf-8') as cfg:
config = yaml.safe_load(cfg)
model = get_generator(config['model'])
model.load_state_dict(torch.load(args.weights_path)['model'])
model = model.cuda()
filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True)) #测试一组图片的路径
test(model, filenames)
执行上述的代码需要在终端指定路径和权重文件:
python test_metrics.py --img_folder GOPRO --weights_path best_fpn.h5
推理时间:指的是图像输入模型到输出模型的运行时间,是衡量算法好坏的重要指标,一般都有达到实时性的要求,只有这样算法才有部署的可能性,才是有意义的!
输出推理时间的代码如下所示:
import torch
from models.mobilenet_v2 import MobileNetV2
import torch.utils.benchmark as benchmark
# 构建你的模型
model = MobileNetV2()
# 准备一个输入样本
input_sample = torch.randn(1, 3, 224, 224)
# 创建一个 benchmark.Timer 对象
timer = benchmark.Timer(
stmt='model(input_sample)',
setup='from __main__ import model, input_sample',
globals=globals()
)
# 运行测量
time_taken = timer.timeit(100) # 这里的 1 表示运行一次,你可以调整为更多次以获得更稳定的结果
# 使用 median 或 mean 获取时间的中位数或平均值
median_time = time_taken.median
mean_time = time_taken.mean
print("Median Forward Pass Time: {:.4f} seconds".format(median_time))
print("Mean Forward Pass Time: {:.4f} seconds".format(mean_time))
"""使用 median() 和 mean() 方法获取 Measurement 对象的中位数和平均值,然后进行格式化输出"""
注意:推理时间和模型的大小相关和图片的大小相关!