去模糊质量衡量指标SSIM,PSNR和推理时间

所有对于算法的改进都是有一个目标的,比如使得图像的质量更好,又或者处理的速度更快,这些都是用来衡量去模糊算法的好坏程度的指标。在图像处理部分,有两个比较常用的衡量图像质量的指标如下
SSIM (Structural SIMilarity) 结构相似性:一种全参考的图像质量评价指标,它分别从亮度、对比度、结构三*方面度量图像相似性
PSNR (Peak Signal-to-Noise Ratio) 峰值信噪比:一种全参考的图像质量评价指标
计算两个指标的代码如下

"""这段代码是一个用于测试图像的脚本,包括图像的预处理、模型推断、评价指标计算等功能"""
from __future__ import print_function
import argparse
import numpy as np
import torch
import cv2
import yaml
import os
from torchvision import models, transforms
from torch.autograd import Variable
import shutil
import glob
import tqdm
#from util.metrics import PSNR 源码修改的地方
from albumentations import Compose, CenterCrop, PadIfNeeded
from PIL import Image
#from ssim.ssimlib import SSIM 源码修改的地方

from skimage.metrics import structural_similarity as SSIM #新增的地方,导入相关的库
from skimage.metrics import peak_signal_noise_ratio as PSNR #新增的地方,导入相关的库
from models.networks import get_generator


def get_args(): #定义一个函数get_args,用于解析命令行参数,包括图像文件夹路径和权重文件路径
	#1.测试批量图片
	parser = argparse.ArgumentParser('Test an image')
	parser.add_argument('--img_folder', required=True, help='GoPRO Folder') # default='GOPRO', 图像文件路径
	parser.add_argument('--weights_path', required=True, help='Weights path') #default='best_fpn.h5', 权重文件路径
	return parser.parse_args()


def prepare_dirs(path): #定义了一个函数prepare_dirs,用于准备目录,如果目录存在则删除并重新创建
	if os.path.exists(path):
		shutil.rmtree(path)
	os.makedirs(path)


def get_gt_image(path): #定义了一个函数get_gt_image,用于获取原始图像
	dir, filename = os.path.split(path)
	base, seq = os.path.split(dir)
	base, _ = os.path.split(base)
	#img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB) #源码修改的地方(注释掉)
	img = cv2.cvtColor(cv2.imread(os.path.join(base, _, 'sharp', filename)), cv2.COLOR_BGR2RGB)
	return img


def test_image(model, image_path):
	"""
	测试单张图像,包括图像的预处理、模型推断、评价指标计算等

	Args:
	    model (torch.nn.Module): 训练好的模型
	    image_path (str): 图像文件路径

	Returns:
	    tuple: PSNR和SSIM评价指标
	"""
	img_transforms = transforms.Compose([
		transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
	])
	size_transform = Compose([
		PadIfNeeded(736, 1280)
	])
	crop = CenterCrop(720, 1280)
	img = cv2.imread(image_path)
	img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
	img_s = size_transform(image=img)['image']
	img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
	img_tensor = img_transforms(img_tensor)
	with torch.no_grad():
		img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
		result_image = model(img_tensor)
	result_image = result_image[0].cpu().float().numpy()
	result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
	result_image = crop(image=result_image)['image']
	result_image = result_image.astype('uint8')
	gt_image = get_gt_image(image_path)
	_, filename = os.path.split(image_path)
	psnr = PSNR(result_image, gt_image)
	# pilFake = Image.fromarray(result_image) #源码修改的地方(注释掉)
	# pilReal = Image.fromarray(gt_image) #源码修改的地方(注释掉)
	# ssim = SSIM(pilFake).cw_ssim_value(pilReal) #源码修改的地方(注释掉)
	ssim = SSIM(result_image, gt_image, multichannel=True) #新增
	return psnr, ssim


def test(model, files):
	"""
	对一组图像进行测试,计算并输出平均的PSNR和SSIM

	Args:
	    model (torch.nn.Module): 训练好的模型
	    files (list): 图像文件路径列表
	"""
	psnr = 0
	ssim = 0
	for file in tqdm.tqdm(files):
		cur_psnr, cur_ssim = test_image(model, file)
		psnr += cur_psnr
		ssim += cur_ssim
	print("PSNR = {}".format(psnr / len(files)))
	print("SSIM = {}".format(ssim / len(files)))


if __name__ == '__main__':
	args = get_args()
	with open('config/config.yaml', encoding='utf-8') as cfg:
		config = yaml.safe_load(cfg)
	model = get_generator(config['model'])
	model.load_state_dict(torch.load(args.weights_path)['model'])
	model = model.cuda()

	filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True)) #测试一组图片的路径
	test(model, filenames)

执行上述的代码需要在终端指定路径和权重文件:
python test_metrics.py --img_folder GOPRO --weights_path best_fpn.h5
推理时间:指的是图像输入模型到输出模型的运行时间,是衡量算法好坏的重要指标,一般都有达到实时性的要求,只有这样算法才有部署的可能性,才是有意义的!
输出推理时间的代码如下所示:

import torch
from models.mobilenet_v2 import MobileNetV2
import torch.utils.benchmark as benchmark

# 构建你的模型
model = MobileNetV2()

# 准备一个输入样本
input_sample = torch.randn(1, 3, 224, 224)

# 创建一个 benchmark.Timer 对象
timer = benchmark.Timer(
    stmt='model(input_sample)',
    setup='from __main__ import model, input_sample',
    globals=globals()
)

# 运行测量
time_taken = timer.timeit(100)  # 这里的 1 表示运行一次,你可以调整为更多次以获得更稳定的结果

# 使用 median 或 mean 获取时间的中位数或平均值
median_time = time_taken.median
mean_time = time_taken.mean

print("Median Forward Pass Time: {:.4f} seconds".format(median_time))
print("Mean Forward Pass Time: {:.4f} seconds".format(mean_time))

"""使用 median() 和 mean() 方法获取 Measurement 对象的中位数和平均值,然后进行格式化输出"""

注意:推理时间和模型的大小相关和图片的大小相关!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值