DeblurGAN-v2:更快更好地去模糊

论文地址:DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better

Github:https://github.com/TAMU-VITA/DeblurGAN      

or   https://github.com/KupynOrest/DeblurGANv2

 概念

我们提出了一种新的端到端生成对抗网络 (GAN),用于单图像运动去模糊,名为 DeblurGAN-v2,它大大提高了最先进的去模糊效率、质量和灵活性。DeblurGAN-v2 基于具有双尺度鉴别器的相对论条件 GAN。我们首次将特征金字塔网络引入去模糊,作为 DeblurGAN-v2 生成器的核心构建块。它可以灵活地与各种主干一起工作,以在性能和效率之间取得平衡。复杂主干的插件(例如,Inception-ResNet-v2)可以实现最先进的去模糊。同时,借助轻量级骨干网(例如,MobileNet 及其变体),DeblurGAN-v2 比最接近的竞争对手快 10-100 倍,同时保持接近最先进的结果,这意味着可以选择实时视频去模糊。我们证明,在去模糊质量(客观和主观)以及效率方面,DeblurGAN-v2 在几个流行的基准测试中获得了非常有竞争力的性能。此外,我们展示了该架构对于一般图像恢复任务也很有效。

DeblurGAN-v2 架构

数据集

训练数据集可通过以下链接下载:

训练

命令

python train.py

训练脚本将在 config/config.yaml 下加载配置

张量板可视化

测试

要在单个图像上进行测试,

python predict.py IMAGE_NAME.jpg

默认情况下,Predictor 使用的预训练模型的名称是“best_fpn.h5”。可以在代码中更改它('weights_path' 参数)。它假设使用了 fpn_inception 主干。如果您想尝试使用不同的主干预训练,请在 config/config.yaml 中的 ['model']['g_name'] 下也指定它。

预训练模型

这里的预训练模型链接已经失效,无法下载。

可以参考博主之前上传的同样的权重,如下。

deblurGANV2预训练模型.zip-算法与数据结构文档类资源-CSDN下载

数据集G型D型损失类型PSNR/SSIM关联
GoPro 测试数据集InceptionResNet-v2双甘拉根29.55/ 0.934fpn_inception.h5
移动网络双甘拉根28.17/ 0.925fpn_mobilenet.h5
MobileNet-DSC双甘拉根28.03/ 0.922
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
DeblurGAN-v2 是一种图像去模糊的深度学习模型,可用于将模糊图像转换为清晰图像。在该模型中,使用了超像素技术来提高去模糊的效果。下面是利用超像素优化DeblurGAN-v2的PyTorch代码: 首先,需要安装以下依赖库: ``` pip install opencv-python pip install scikit-image pip install numpy pip install torch pip install torchvision pip install pydensecrf ``` 然后,加载DeblurGAN-v2模型和测试图像,并生成超像素: ```python import cv2 import torch import numpy as np from skimage.segmentation import slic from skimage.segmentation import mark_boundaries from skimage.color import rgb2gray from models.networks import define_G from options.test_options import TestOptions from util import util from pydensecrf.densecrf import DenseCRF2D # 加载模型 opt = TestOptions().parse() opt.nThreads = 1 opt.batchSize = 1 opt.serial_batches = True opt.no_flip = True model = define_G(opt) util.load_checkpoint(model, opt.pretrained) # 加载测试图像 img_path = 'path/to/test/image' img = cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) h, w, c = img.shape # 生成超像素 segments = slic(img, n_segments=100, sigma=5, compactness=10) ``` 接下来,将每个超像素作为输入,运行DeblurGAN-v2模型进行去模糊: ```python # 对每个超像素进行去模糊 result = np.zeros((h, w, c), dtype=np.float32) for i in np.unique(segments): mask = (segments == i).astype(np.uint8) masked_img = cv2.bitwise_and(img, img, mask=mask) if np.sum(mask) > 0: masked_img = masked_img[np.newaxis, :, :, :] masked_img = torch.from_numpy(masked_img.transpose((0, 3, 1, 2))).float() with torch.no_grad(): output = model(masked_img) output = output.cpu().numpy() output = output.transpose((0, 2, 3, 1)) output = np.squeeze(output) result += output * mask[:, :, np.newaxis] # 对结果进行后处理 result /= 255.0 result = np.clip(result, 0, 1) result = (result * 255).astype(np.uint8) ``` 最后,使用密集条件随机场(DenseCRF)算法对结果进行后处理,以进一步提高去模糊的效果: ```python # 使用DenseCRF算法进行后处理 d = DenseCRF2D(w, h, 2) result_softmax = np.stack([result, 255 - result], axis=0) result_softmax = result_softmax.astype(np.float32) / 255.0 unary = -np.log(result_softmax) unary = unary.reshape((2, -1)) d.setUnaryEnergy(unary) d.addPairwiseGaussian(sxy=5, compat=3) d.addPairwiseBilateral(sxy=20, srgb=3, rgbim=img, compat=10) q = d.inference(5) q = np.argmax(np.array(q), axis=0).reshape((h, w)) result = q * 255 ``` 完整代码如下: ```python import cv2 import torch import numpy as np from skimage.segmentation import slic from skimage.segmentation import mark_boundaries from skimage.color import rgb2gray from models.networks import define_G from options.test_options import TestOptions from util import util from pydensecrf.densecrf import DenseCRF2D # 加载模型 opt = TestOptions().parse() opt.nThreads = 1 opt.batchSize = 1 opt.serial_batches = True opt.no_flip = True model = define_G(opt) util.load_checkpoint(model, opt.pretrained) # 加载测试图像 img_path = 'path/to/test/image' img = cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) h, w, c = img.shape # 生成超像素 segments = slic(img, n_segments=100, sigma=5, compactness=10) # 对每个超像素进行去模糊 result = np.zeros((h, w, c), dtype=np.float32) for i in np.unique(segments): mask = (segments == i).astype(np.uint8) masked_img = cv2.bitwise_and(img, img, mask=mask) if np.sum(mask) > 0: masked_img = masked_img[np.newaxis, :, :, :] masked_img = torch.from_numpy(masked_img.transpose((0, 3, 1, 2))).float() with torch.no_grad(): output = model(masked_img) output = output.cpu().numpy() output = output.transpose((0, 2, 3, 1)) output = np.squeeze(output) result += output * mask[:, :, np.newaxis] # 对结果进行后处理 result /= 255.0 result = np.clip(result, 0, 1) result = (result * 255).astype(np.uint8) # 使用DenseCRF算法进行后处理 d = DenseCRF2D(w, h, 2) result_softmax = np.stack([result, 255 - result], axis=0) result_softmax = result_softmax.astype(np.float32) / 255.0 unary = -np.log(result_softmax) unary = unary.reshape((2, -1)) d.setUnaryEnergy(unary) d.addPairwiseGaussian(sxy=5, compat=3) d.addPairwiseBilateral(sxy=20, srgb=3, rgbim=img, compat=10) q = d.inference(5) q = np.argmax(np.array(q), axis=0).reshape((h, w)) result = q * 255 # 显示结果 result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) cv2.imshow('result', result) cv2.waitKey(0) cv2.destroyAllWindows() ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

武大人民泌外I科人工智能团队

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值