Image Deblurring with NAFNet in PaddleGAN

1. Project Overview

1.1 Background

  • NAFNet is an image restoration model proposed by MEGVII Research. It performs very well on both image deblurring and denoising: it is computationally efficient while outperforming previous SOTA methods, as shown in the figure below. On the stereo image super-resolution task, NAFSSR, a stereo SR model built on NAFNet, won the stereo super-resolution track of NTIRE 2022.

[Figure: NAFNet image restoration results]

1.2 Goals

  • Although the training, testing, and prediction code merged into PaddleGAN targets denoising, the NAFNet network itself is already in the repo, so with minor changes you can try its performance on the deblurring task
  • This project does not involve training. It converts the two best NAFNet checkpoints, trained on the GoPro and REDS datasets respectively, to Paddle weights and performs image deblurring with the NAFNet in PaddleGAN
  • The torch-to-paddle weight conversion is not detailed here since the code is much the same across projects; see the BSRGAN reproduction for blind super-resolution of real degraded images, as well as the PPSIG PMBANet depth-map super-resolution reproduction. A minimal conversion sketch follows below.
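For reference, a minimal conversion sketch might look like the code below. It assumes the parameter names already match between the two implementations, that the official .pth checkpoint stores its weights under a 'params' key, and that only 2-D nn.Linear weights need transposing; the file names are placeholders.

import torch
import paddle

# Minimal torch -> paddle weight conversion sketch (file names are placeholders)
torch_ckpt = torch.load("NAFNet-REDS-width64.pth", map_location="cpu")
torch_state = torch_ckpt.get("params", torch_ckpt)  # assumed checkpoint layout

paddle_state = {}
for name, tensor in torch_state.items():
    array = tensor.detach().cpu().numpy()
    # nn.Linear weights are [out, in] in torch but [in, out] in paddle
    if name.endswith(".weight") and array.ndim == 2:
        array = array.T
    paddle_state[name] = array

paddle.save(paddle_state, "NAFNet-REDS-width64.pdparams")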

2. How to Use

  • First, I converted NAFNet's deblurring weights to Paddle format and mounted them in this project's dataset. There are two checkpoints:
    • NAFNet-GoPro-width64.pdparams, trained on the GoPro dataset, mainly for removing motion blur
    • NAFNet-REDS-width64.pdparams, trained on the REDS dataset, mainly for restoring blurry images with compression artifacts
  • Next, we call these weights through PaddleGAN to deblur our own images. Follow me!
# Clone the repo. This step is slow due to the external network rate limit, so you can use the already-cloned folder instead and skip it
# !git clone https://github.com/PaddlePaddle/PaddleGAN
Cloning into 'PaddleGAN'...
remote: Enumerating objects: 5401, done.
remote: Counting objects: 100% (203/203), done.
remote: Compressing objects: 100% (159/159), done.
remote: Total 5401 (delta 101), reused 95 (delta 41), pack-reused 5198
Receiving objects: 100% (5401/5401), 163.52 MiB | 10.71 MiB/s, done.
Resolving deltas: 100% (3499/3499), done.
Checking connectivity... done.
# Install dependencies
%cd PaddleGAN/
!pip install -r requirements.txt
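Optionally, before running anything else you can confirm that this Paddle build sees the GPU (a quick check, not required by the rest of the notebook):

# Optional: verify that Paddle was built with CUDA and which device it will use
import paddle
print(paddle.device.is_compiled_with_cuda())  # True on the GPU runtime
print(paddle.device.get_device())             # e.g. 'gpu:0'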
import cv2
from glob import glob
from natsort import natsorted
import numpy as np
import os
from tqdm import tqdm

import paddle

from ppgan.models.generators import NAFNetLocal
from ppgan.utils.download import get_path_from_url
from ppgan.apps.base_predictor import BasePredictor

# Model configuration
model_cfgs = {
    'Deblur': {
        'img_channel': 3,
        'width': 64,
        'enc_blk_nums': [1, 1, 1, 28],
        'middle_blk_num': 1,
        'dec_blk_nums': [1, 1, 1, 1]
    }
}
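The configuration above corresponds to the width-64 NAFNet used for deblurring: 64 base channels, four encoder stages with [1, 1, 1, 28] blocks, one middle block, and one block per decoder stage. As a quick optional sanity check (a sketch only, not required by the rest of the project), the generator can be built straight from this dict and its parameters counted:

# Optional sanity check: build the generator from the config above and count its parameters
net = NAFNetLocal(**model_cfgs['Deblur'])
n_params = sum(int(np.prod(p.shape)) for p in net.parameters())
print(f"NAFNet (Deblur, width 64): {n_params / 1e6:.2f}M parameters")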

# Define the deblurring predictor class
class NAFNetDeblurer(BasePredictor):

    def __init__(self,
                 output_path='output_dir',
                 weight_path=None):
        self.output_path = output_path
        task = 'Deblur'
        self.task = task

        checkpoint = paddle.load(weight_path)

        self.generator = NAFNetLocal(
            img_channel=model_cfgs[task]['img_channel'],
            width=model_cfgs[task]['width'],
            enc_blk_nums=model_cfgs[task]['enc_blk_nums'],
            middle_blk_num=model_cfgs[task]['middle_blk_num'],
            dec_blk_nums=model_cfgs[task]['dec_blk_nums'])

        self.generator.set_state_dict(checkpoint)
        self.generator.eval()

    def get_images(self, images_path):
        if os.path.isdir(images_path):
            return natsorted(
                glob(os.path.join(images_path, '*.jpeg')) +
                glob(os.path.join(images_path, '*.jpg')) +
                glob(os.path.join(images_path, '*.JPG')) +
                glob(os.path.join(images_path, '*.png')) +
                glob(os.path.join(images_path, '*.PNG')))
        else:
            return [images_path]

    def imread_uint(self, path, n_channels=3):
        #  input: path
        # output: HxWx3(RGB or GGG), or HxWx1 (G)
        if n_channels == 1:
            img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
            img = np.expand_dims(img, axis=2)  # HxWx1
        elif n_channels == 3:
            img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
            if img.ndim == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
            else:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB

        return img

    def uint2single(self, img):

        return np.float32(img / 255.)

    # convert single (HxWxC) to 3-dimensional paddle tensor
    def single2tensor3(self, img):
        return paddle.Tensor(np.ascontiguousarray(
            img, dtype=np.float32)).transpose([2, 0, 1])

    def run(self, images_path=None):
        os.makedirs(self.output_path, exist_ok=True)
        task_path = os.path.join(self.output_path, self.task)
        os.makedirs(task_path, exist_ok=True)
        image_files = self.get_images(images_path)
        for image_file in tqdm(image_files):
            img_L = self.imread_uint(image_file, 3)

            image_name = os.path.basename(image_file)
            img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(task_path, image_name), img)

            tmps = image_name.split('.')
            assert len(
                tmps) == 2, f'Invalid image name: {image_name}, too many "."'
            restoration_save_path = os.path.join(
                task_path, f'{tmps[0]}_restoration.{tmps[1]}')

            img_L = self.uint2single(img_L)

            # HWC to CHW, numpy to tensor
            img_L = self.single2tensor3(img_L)
            img_L = img_L.unsqueeze(0)
            with paddle.no_grad():
                output = self.generator(img_L)

            restored = paddle.clip(output, 0, 1)

            restored = restored.numpy()
            restored = restored.transpose(0, 2, 3, 1)
            restored = restored[0]
            restored = restored * 255
            restored = restored.astype(np.uint8)

            cv2.imwrite(restoration_save_path,
                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))

        print('Done, output path is:', task_path)

2.1 Deblurring Ordinary Images

  • Ordinary images are usually no larger than 4K, so they can be fed into the network directly; just run the following

Note: the weights used in this demo were trained on the REDS dataset

# Define the output path
output_path = r"../work/output"
# Path to the weights
weight_path = r"../data/data174576/NAFNet-REDS-width64.pdparams" 
# Instantiate the deblurring predictor
deblur_predictor = NAFNetDeblurer(output_path, weight_path)
W1030 19:54:46.673647   192 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1030 19:54:46.677815   192 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
# Define the input path
input_path = r"../work/inputs/"
# Run prediction
deblur_predictor.run(images_path=input_path)
100%|██████████| 3/3 [00:03<00:00,  1.42s/it]

Done, output path is: ../work/output/Deblur
  • Display the prediction results
# Show the prediction results
import numpy as np
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

def imread(img_path):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

def display(img1, img2):
    fig = plt.figure(figsize=(25, 10))
    ax1 = fig.add_subplot(1, 2, 1) 
    plt.title('Input image', fontsize=16)
    ax1.axis('off')
    ax2 = fig.add_subplot(1, 2, 2)
    plt.title('NAFNet output', fontsize=16)
    ax2.axis('off')
    ax1.imshow(img1)
    ax2.imshow(img2)
input_path = '../work/inputs/blurry-reds-1.jpg'
output_path = '../work/output/Deblur/blurry-reds-1_restoration.jpg'

img_input = imread(input_path)
img_output = imread(output_path)
display(img_input, img_output)

[Figure: blurry input (left) vs. NAFNet deblurred output (right)]
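To target motion-blurred photos instead, point the same predictor at the GoPro checkpoint. The sketch below assumes the GoPro weight is mounted in the same dataset directory as the REDS weight, and writes to a separate output folder:

# Hypothetical: use the GoPro-trained checkpoint for motion-blur removal.
# The weight path assumes it is mounted alongside the REDS checkpoint.
gopro_predictor = NAFNetDeblurer(
    output_path=r"../work/output_gopro",
    weight_path=r"../data/data174576/NAFNet-GoPro-width64.pdparams")
gopro_predictor.run(images_path=r"../work/inputs/")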

2.2 Deblurring Large Images Above 4K

  • For the larger images we need to restore in practice, predicting on the whole image at once will exhaust GPU memory (Out of Memory, OOM), so we predict on tiles
  • Inheriting from the deblurring predictor defined above, we add a tile-based predictor class as follows
class CropPredictor(NAFNetDeblurer):
    def __init__(self,
                 output_path='output_dir',
                 weight_path=None):
        super(CropPredictor, self).__init__(output_path, weight_path)

    def crop_predict(self, img_lq):
        sf = self.sf
        tile = self.tile
        overlap = self.overlap
        b, c, h, w = img_lq.shape
        tile_overlap = overlap
        stride = tile - tile_overlap
        h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
        w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
        E = paddle.zeros([b, c, h*sf, w*sf], dtype=img_lq.dtype)
        W = paddle.zeros_like(E)

        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                h_idx = int(h_idx)
                w_idx = int(w_idx)
                in_patch = img_lq[:, :,h_idx:h_idx+tile, w_idx:w_idx+tile]
                out_patch = self.generator(in_patch)
                out_patch_mask = paddle.ones_like(out_patch)

                E[:, :, h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf] += out_patch
                W[:, :, h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf] += out_patch_mask

        output = E.divide(W)
        return output

    def run_patches(self, images_path=None, tile=1024, overlap=128):
        os.makedirs(self.output_path, exist_ok=True)
        task_path = os.path.join(self.output_path, self.task)
        os.makedirs(task_path, exist_ok=True)
        image_files = self.get_images(images_path)
        self.tile = tile
        self.overlap = overlap
        self.sf = 1

        for image_file in tqdm(image_files):
            img_L = self.imread_uint(image_file, 3)

            image_name = os.path.basename(image_file)
            img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(task_path, image_name), img)

            tmps = image_name.split('.')
            assert len(
                tmps) == 2, f'Invalid image name: {image_name}, too many "."'
            restoration_save_path = os.path.join(
                task_path, f'{tmps[0]}_restoration.{tmps[1]}')

            img_L = self.uint2single(img_L)

            # HWC to CHW, numpy to tensor
            img_L = self.single2tensor3(img_L)
            img_L = img_L.unsqueeze(0)
            with paddle.no_grad():
                output = self.crop_predict(img_L)

            restored = paddle.clip(output, 0, 1)

            restored = restored.numpy()
            restored = restored.transpose(0, 2, 3, 1)
            restored = restored[0]
            restored = restored * 255
            restored = restored.astype(np.uint8)

            cv2.imwrite(restoration_save_path,
                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))

        print('Done, output path is:', task_path)
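To see how the sliding-window indexing in crop_predict covers an image, here is a small standalone check (hypothetical 4K-sized frame; tile and overlap match the run_patches defaults):

# Standalone check of the sliding-window indexing used in crop_predict
# (hypothetical frame size; tile/overlap follow the run_patches defaults)
h, w = 2160, 3840
tile, overlap = 1024, 128
stride = tile - overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
print(len(h_idx_list) * len(w_idx_list), "patches")            # 3 x 5 = 15 patches
print(h_idx_list[-1] + tile == h, w_idx_list[-1] + tile == w)  # windows end exactly at the borders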

# Output path for sliding-window prediction
crop_output_path = r"../work/crop_output"
# Weight path
weight_path = r"../data/data174576/NAFNet-REDS-width64.pdparams" 
# Instantiate the sliding-window (tile-based) deblurring predictor
croppredictor = CropPredictor(crop_output_path, weight_path=weight_path)
W1102 21:20:36.032223   212 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1102 21:20:36.036273   212 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
# Directory containing the large images
bigimages_path = r"../work/big_inputs/"

# Start prediction
croppredictor.run_patches(images_path=bigimages_path, tile=1024, overlap=128)
100%|██████████| 1/1 [00:08<00:00,  8.01s/it]

Done, output path is: ../work/crop_output/Deblur
# Show the results
input_path = '../work/big_inputs/beautiful.png'
output_path = '../work/crop_output/Deblur/beautiful_restoration.png'

img_input = imread(input_path)
img_output = imread(output_path)
display(img_input, img_output)

[Figure: large blurry input vs. tiled NAFNet output]

3. Summary

  • This project showed how to use the NAFNet already merged into PaddleGAN to restore blurry images whenever a deblurring task comes up; it is actually quite practical in daily life
  • NAFNet can also be used for stereo super-resolution; I plan to build that out when I get the chance~


This article is a repost.
Original project link
