代码整理: PyTorch DL AI 模型测试

深度学习模型测试代码

个人看过觉得比较合适的代码部分记录于此,以后一些部分的代码抄就完事了。随缘更新

1. Multi-Stage Progressive Image Restoration (CVPR 2021) demo.py

Code:https://github.com/swz30/MPRNet

2022/6/5:这是一篇图像复原(去模糊、去噪、去雨)方面的论文,输入数据都是图像格式。代码简洁明了;用于其他任务时,只需添加模型路径参数并修改模型读取处的代码;如有需要,还可以添加计时函数,输出平均运行时间。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
import os
from runpy import run_path
from skimage import img_as_ubyte
from collections import OrderedDict
from natsort import natsorted
from glob import glob
import cv2
import argparse

# Command-line interface: where to read inputs, where to write results,
# and which restoration task (i.e. which MPRNet variant) to run.
parser = argparse.ArgumentParser(description='Demo MPRNet')
parser.add_argument(
    '--input_dir',
    type=str,
    default='./samples/input/',
    help='Input images',
)
parser.add_argument(
    '--result_dir',
    type=str,
    default='./samples/output/',
    help='Directory for results',
)
parser.add_argument(
    '--task',
    type=str,
    required=True,
    choices=['Deblurring', 'Denoising', 'Deraining'],
    help='Task to run',
)

args = parser.parse_args()

def save_img(filepath, img):
    """Write an RGB image array to ``filepath`` using OpenCV.

    OpenCV expects BGR channel order, so the array is converted from RGB
    before it is written.

    Args:
        filepath (str): Destination path; its extension selects the encoder.
        img: RGB image array accepted by ``cv2.cvtColor`` (in this script,
            the uint8 array produced by ``img_as_ubyte``).

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file was not written.
    """
    # cv2.imwrite reports failure (missing directory, unsupported extension,
    # ...) only through its boolean return value, so check it instead of
    # silently producing no output file.
    if not cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Could not write image to {filepath}")

def load_checkpoint(model, weights, map_location=None):
    """Load the ``"state_dict"`` entry of a checkpoint file into ``model``.

    If the keys do not match ``model`` as stored, a leading ``module.``
    prefix (added when a model is saved while wrapped in
    ``nn.DataParallel``) is removed and loading is retried.

    Args:
        model (torch.nn.Module): Model that receives the weights.
        weights (str): Path to the checkpoint file.
        map_location: Forwarded to ``torch.load``. ``None`` (the default)
            keeps the previous behaviour of calling ``torch.load(weights)``.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the keys match neither as stored nor after
            removing the ``module.`` prefix.
    """
    checkpoint = torch.load(weights, map_location=map_location)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict raises RuntimeError on key/shape mismatch. Strip the
        # prefix only from keys that actually carry it, instead of blindly
        # dropping the first 7 characters of every key.
        prefix = "module."
        stripped = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(stripped)

task = args.task
inp_dir = args.input_dir
out_dir = args.result_dir

# exist_ok=True: no error if the output directory already exists.
os.makedirs(out_dir, exist_ok=True)

# Collect .jpg/.JPG/.png/.PNG inputs in natural order (img2 before img10).
# A set is used because on Windows glob matching is case-insensitive, so
# '*.jpg' and '*.JPG' return the same files and each image would otherwise
# be processed twice.
patterns = ('*.jpg', '*.JPG', '*.png', '*.PNG')
files = natsorted({path for pattern in patterns
                   for path in glob(os.path.join(inp_dir, pattern))})

if not files:
    raise FileNotFoundError(f"No files found at {inp_dir}")

# Load the task-specific architecture: execute <task>/MPRNet.py and take the
# MPRNet class from the namespace it defines.
load_file = run_path(os.path.join(task, "MPRNet.py"))
model = load_file['MPRNet']()
model.cuda()

# Weights live at <task>/pretrained_models/model_<task lower-cased>.pth.
weights = os.path.join(task, "pretrained_models", "model_" + task.lower() + ".pth")
load_checkpoint(model, weights)
model.eval()

# Input height/width are padded up to a multiple of this value (the original
# author attributes this requirement to the U-Net-style architecture).
img_multiple_of = 8

for file_ in files:
    img = Image.open(file_).convert('RGB')
    input_ = TF.to_tensor(img).unsqueeze(0).cuda()  # [1, C, H, W]

    # Pad bottom/right so H and W become multiples of img_multiple_of.
    # (-x) % m is 0 when x is already a multiple of m, otherwise the distance
    # to the next multiple.
    h, w = input_.shape[2], input_.shape[3]
    padh = (-h) % img_multiple_of
    padw = (-w) % img_multiple_of
    # 'reflect' padding requires the pad to be smaller than the padded
    # dimension; fall back to 'replicate' for tiny images.
    pad_mode = 'reflect' if padh < h and padw < w else 'replicate'
    input_ = F.pad(input_, (0, padw, 0, padh), pad_mode)

    with torch.no_grad():
        restored = model(input_)
    # The model returns one output per stage; index 0 is the final output
    # (per the original author). Clamp it to the valid [0, 1] range.
    restored = torch.clamp(restored[0], 0, 1)

    # Remove the padding again.
    restored = restored[:, :, :h, :w]

    # [B, C, H, W] -> [B, H, W, C] numpy, then the single image as uint8.
    restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
    restored = img_as_ubyte(restored[0])

    # "A/B/c.jpg" -> "c"; results are always written as PNG.
    stem = os.path.splitext(os.path.basename(file_))[0]
    save_img(os.path.join(out_dir, stem + '.png'), restored)

print(f"Files saved at {out_dir}")

2. FFA-Net: Feature Fusion Attention Network for Single Image Dehazing (AAAI 2020) test.py

Code:https://github.com/zhilin007/FFA-Net

2022/6/6:FFA-Net是一篇图像去雾方向的论文,我曾复现过这篇论文并进行了改进。虽然它在RESIDE数据集上有着非常好的指标,但在真实世界场景中几乎不起作用,或者说只对近景的薄雾有一丢丢作用。代码的命名略显粗糙,tensorShow函数值得Copy一下。

import os,argparse
import numpy as np
from PIL import Image
from models import *
import torch
import torch.nn as nn
import torchvision.transforms as tfs 
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
# Absolute path of the current working directory with a trailing '/', so
# relative names can be appended by plain string concatenation.
# NOTE(review): this module-level name shadows the built-in abs().
abs = f"{os.getcwd()}/"
# Display input/output images side by side.
def tensorShow(tensors, titles=('haze',)):
    """Show each tensor (or batch of tensors) as one panel of a single figure.

    Args:
        tensors: Sequence of image tensors; each is passed to
            ``torchvision.utils.make_grid``.
        titles: Panel titles, matched to ``tensors`` by position. Pairing uses
            ``zip``, so tensors without a matching title are not drawn.

    Note:
        Panels are laid out on a fixed 2x2 grid, so at most four tensors can
        be shown.
    """
    fig = plt.figure()
    for i, (tensor, title) in enumerate(zip(tensors, titles)):
        # .detach().cpu() lets tensors that require grad or live on the GPU
        # be converted to numpy.
        grid = make_grid(tensor.detach().cpu())
        ax = fig.add_subplot(2, 2, i + 1)
        # Channels-first grid -> channels-last array for imshow.
        ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        ax.set_title(title)
    plt.show()
    # This is called once per image; close the figure so figures do not
    # accumulate (e.g. with a non-interactive backend).
    plt.close(fig)

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='its', help='its or ots')
parser.add_argument('--test_imgs', type=str, default='test_imgs', help='Test imgs folder')
opt = parser.parse_args()
dataset = opt.task

# Model hyper-parameters; they must match the checkpoint (they are also part
# of its file name below).
gps = 3
blocks = 19

# Input folder and output folder, both under the current working directory.
img_dir = abs + opt.test_imgs + '/'
output_dir = abs + f'pred_FFA_{dataset}/'
print("pred_dir:", output_dir)

if not os.path.exists(output_dir):
    os.mkdir(output_dir)

# Checkpoint path.
model_dir = abs + f'trained_models/{dataset}_train_ffa_{gps}_{blocks}.pk'

# Build the network, load the weights and switch to eval mode.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckp = torch.load(model_dir, map_location=device)
net = FFA(gps=gps, blocks=blocks)
net = nn.DataParallel(net)
net.load_state_dict(ckp['model'])
# load_state_dict copies values but does not move parameters, and
# nn.DataParallel requires the wrapped module to be on the first GPU when
# CUDA is used, so move the model explicitly.
net = net.to(device)
net.eval()

# The preprocessing is the same for every image, so build it once.
normalize = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])
])
to_tensor = tfs.ToTensor()

for im in os.listdir(img_dir):
    path = img_dir + im
    # Skip sub-directories and other non-file entries.
    if not os.path.isfile(path):
        continue
    print(f'\r {im}', end='', flush=True)
    # Force 3 channels: grayscale/palette/RGBA inputs would not match the
    # 3-channel Normalize statistics.
    haze = Image.open(path).convert('RGB')
    haze1 = normalize(haze).unsqueeze(0).to(device)
    haze_no = to_tensor(haze).unsqueeze(0)
    with torch.no_grad():
        pred = net(haze1)
    pred = pred.clamp(0, 1).cpu()
    ts = torch.squeeze(pred)
    # Show every input/output pair.
    tensorShow([haze_no, pred], ['haze', 'pred'])
    # "c.v1.jpg" -> "c.v1_FFA.png": only the last extension is removed.
    stem = os.path.splitext(im)[0]
    vutils.save_image(ts, output_dir + stem + '_FFA.png')

3. COIN: COmpression with Implicit Neural representations 绘图代码 plots.py

Code:https://github.com/EmilienDupont/coin

# Based on https://github.com/InterDigitalInc/CompressAI/blob/master/compressai/utils/plot/__main__.py
import imageio
import json5 as json
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from matplotlib import cm
from pathlib import Path


# Label of this project's own method; the plotting code compares scatter
# names against it to highlight it.
ours = 'COIN'

# Fixed color per method so every plot uses the same color for it.
name_to_color = {
    ours: mcolors.TABLEAU_COLORS['tab:blue'],
    'BMS': mcolors.TABLEAU_COLORS['tab:orange'],
    'MBT': mcolors.TABLEAU_COLORS['tab:green'],
    'CST': mcolors.TABLEAU_COLORS['tab:red'],
    'JPEG': mcolors.TABLEAU_COLORS['tab:purple'],
    'JPEG2000': mcolors.TABLEAU_COLORS['tab:brown'],
    'BPG': mcolors.TABLEAU_COLORS['tab:pink'],
    'VTM': mcolors.TABLEAU_COLORS['tab:gray'],
}

# Colormap for the residual plots, discretised into 100 levels.
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap takes the same (name, lut) arguments on old and new releases.
viridis = plt.get_cmap('viridis', 100)


def parse_json_file(filepath, metric='psnr'):
    """Parse one rate-distortion result file.

    The file must contain a ``"results"`` object holding a ``"bpp"`` list and
    a list for ``metric``. An optional top-level ``"name"`` overrides the
    label derived from the file name (the text before its first '.').

    Args:
        filepath (string or Path): Path to results json file.
        metric (string): Metric to use for plot. ``'ms-ssim'`` values are
            converted to dB as ``-10 * log10(1 - value)``.

    Returns:
        dict: ``{'name': label, 'xs': bpp values, 'ys': metric values}``.

    Raises:
        ValueError: if the file cannot be parsed, lacks ``results``/``bpp``,
            or does not contain ``metric``.
    """
    filepath = Path(filepath)
    name = filepath.name.split('.')[0]

    with filepath.open('r', encoding='utf-8') as f:
        try:
            data = json.load(f)
        except ValueError:
            # Both the stdlib json decoder (JSONDecodeError) and json5 report
            # syntax errors as ValueError. json5 does not provide
            # ``decoder.JSONDecodeError``, so catching that name would itself
            # fail with AttributeError and hide the real parse error.
            print(f'Error reading file {filepath}')
            raise

    # Structural checks (the reference files are in coin-main/baselines).
    if 'results' not in data or 'bpp' not in data['results']:
        raise ValueError(f'Invalid file {filepath}')

    if metric not in data['results']:
        raise ValueError(
            f'Error: metric {metric} not available.'
            f' Available metrics: {", ".join(data["results"].keys())}'
        )

    ys = data['results'][metric]
    if metric == 'ms-ssim':
        # Convert MS-SSIM to dB.
        ys = -10 * np.log10(1 - np.array(ys))

    return {
        'name': data.get('name', name),
        'xs': data['results']['bpp'],
        'ys': ys,
    }


def rate_distortion(scatters, title=None, ylabel='PSNR [dB]', output_file=None,
                    limits=None, show=False, figsize=None):
    """Create a rate-distortion plot (bit-rate vs. quality) from ``scatters``.

    Args:
        scatters (list of dicts): One dict per method with keys ``'name'``,
            ``'xs'`` (bpp values) and ``'ys'`` (quality values), as returned
            by ``parse_json_file``.
        title (string): Optional plot title.
        ylabel (string): Label of the y axis.
        output_file (string): If not None, save the plot there and close it.
        limits (sequence of 4 numbers): ``(xmin, xmax, ymin, ymax)`` passed
            to ``Axes.axis``.
        show (bool): If True shows plot.
        figsize (tuple of 2 numbers): Figure size in inches; default (7, 4).
    """
    # Learned codecs are drawn with solid lines, classical ones dashed.
    learned = {ours, 'BMS', 'MBT', 'CST'}

    if figsize is None:
        figsize = (7, 4)
    fig, ax = plt.subplots(figsize=figsize)
    for sc in scatters:
        name = sc['name']
        is_ours = name == ours
        style = {
            'label': name,
            'linewidth': 2.5 if is_ours else 1,
            'markersize': 10 if is_ours else 6,
        }
        # Methods without a fixed color use matplotlib's color cycle instead
        # of raising KeyError.
        if name in name_to_color:
            style['c'] = name_to_color[name]
        pattern = '.-' if name in learned else '.--'
        ax.plot(sc['xs'], sc['ys'], pattern, **style)

    ax.set_xlabel('Bit-rate [bpp]')
    ax.set_ylabel(ylabel)
    ax.grid()
    if limits is not None:
        ax.axis(limits)
    ax.legend(loc='lower right')

    if title:
        ax.title.set_text(title)

    if show:
        plt.show()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        # Close this figure specifically, not whichever figure is current.
        plt.close(fig)


def plot_rate_distortion(filepaths=('results.json',
                                    'baselines/compressai-bmshj2018-hyperprior.json',
                                    'baselines/compressai-mbt2018.json',
                                    'baselines/compressai-cheng2020-anchor.json',
                                    'baselines/jpeg.json', 'baselines/jpeg2000.json',
                                    'baselines/bpg_444_x265_ycbcr.json',
                                    'baselines/vtm.json'),
                         output_file=None, limits=None):
    """Create a PSNR rate-distortion plot from several result json files.

    Args:
        filepaths (iterable of string): Paths to result json files. The
            default is an immutable tuple to avoid a shared mutable default.
        output_file (string): Path to save image.
        limits (tuple of float): Limits of plot.
    """
    # One scatter (name, bpp values, PSNR values) per result file.
    scatters = [parse_json_file(f, 'psnr') for f in filepaths]
    rate_distortion(scatters, output_file=output_file, limits=limits)


def plot_model_size(output_file=None, show=False):
    """Plot a log-scale bar chart of model sizes in kB.

    Args:
        output_file (string): If not None, save plot at output_file.
        show (bool): If True shows plot.

    Notes:
        Data for all baselines was computed using the compressAI library
        https://github.com/InterDigitalInc/CompressAI
    """
    model_names = ['COIN', 'BMS', 'MBT', 'CST']
    model_sizes = [14.7455, 10135.868, 24764.604, 31834.464]  # in kB

    # Draw on a fresh figure rather than on whatever pyplot figure happens to
    # be current (e.g. one left open by an earlier plot without output_file).
    fig = plt.figure()

    plt.grid(zorder=0, which="both", axis="y")  # Ensure grid is at the back

    bars = plt.bar(model_names, model_sizes, log=True, zorder=10)
    for bar, model_name in zip(bars, model_names):
        bar.set_color(name_to_color[model_name])
    plt.ylabel("Model size [kB]")

    fig.set_size_inches(3, 4)

    if show:
        plt.show()

    if output_file:
        fig.savefig(output_file, format='png', dpi=400, bbox_inches='tight')
        plt.close(fig)


def plot_residuals(path_original='kodak-dataset/kodim15.png',
                   path_coin='imgs/kodim15_coin_bpp_03.png',
                   path_jpeg='imgs/kodim15_jpeg_bpp_03.jpg',
                   output_file=None, show=False, max_residual=0.3,
                   title_fontsize=6):
    """Compare COIN and JPEG compression of one image, with residual maps.

    Draws a 2x3 grid. The top row shows the original, the COIN reconstruction
    and its residual. The bottom row shows the JPEG reconstruction and its
    residual; the bottom-left cell is left empty.

    Args:
        path_original (string): Path to original image.
        path_coin (string): Path to image compressed with COIN.
        path_jpeg (string): Path to image compressed with JPEG.
        output_file (string): If not None, save plot at output_file.
        show (bool): If True shows plot.
        max_residual (float): Residual value (after dividing pixels by 255)
            mapped to the top of the colormap. A small value makes residuals
            easier to see.
        title_fontsize (int): Font size of the panel titles.
    """
    def read_scaled(path):
        # Read an image and divide its pixel values by 255.
        return imageio.imread(path) / 255.

    def residual_map(image, reference):
        # Mean absolute difference over the last axis, scaled by max_residual
        # and colored with viridis; only the first three (RGB) outputs are kept.
        error = np.abs(image - reference).mean(axis=-1)
        return viridis(error / max_residual)[:, :, :3]

    original = read_scaled(path_original)
    coin = read_scaled(path_coin)
    jpeg = read_scaled(path_jpeg)

    # (subplot index in the 2x3 grid, image, title)
    panels = [
        (1, original, 'Original'),
        (2, coin, 'COIN'),
        (3, residual_map(coin, original), 'COIN Residual'),
        (5, jpeg, 'JPEG'),
        (6, residual_map(jpeg, original), 'JPEG Residual'),
    ]
    for cell, image, label in panels:
        plt.subplot(2, 3, cell)
        plt.imshow(image)
        plt.axis('off')
        plt.gca().set_title(label, fontsize=title_fontsize)

    plt.subplots_adjust(wspace=0.1, hspace=0)

    if show:
        plt.show()

    if output_file:
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.clf()
        plt.close()


if __name__ == '__main__':
    # Rate-distortion curves for COIN and all baselines.
    plot_rate_distortion(output_file='rate_distortion.png', limits=(0, 1, 22, 38))
    # Log-scale bar chart of model sizes.
    plot_model_size(output_file='model_sizes.png')
    # Residual comparisons: the first run uses the default kodim15 *_bpp_03
    # inputs, the second overrides them with the *_bpp_015 files.
    residual_runs = [
        {'output_file': 'residuals_kodim15_bpp_03.png'},
        {'output_file': 'residuals_kodim15_bpp_015.png',
         'path_coin': 'imgs/kodim15_coin_bpp_015.png',
         'path_jpeg': 'imgs/kodim15_jpeg_bpp_015.jpg'},
    ]
    for run_kwargs in residual_runs:
        plot_residuals(**run_kwargs)

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

倘若我问心无愧呢丶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值