图像去模糊:跑通DeblurGAN-v2

目录

一.环境的配置

二.跑通测试predict.py

三.跑通训练train.py

1.数据准备

2.数据增强方式

3.加载预训练模型

4.模型训练结果存储问题


工程:https://github.com/VITA-Group/DeblurGANv2

一.环境的配置

直接使用python train.py缺什么库,就安装什么库;

二.跑通测试predict.py

需要设置以下参数:

--img_pattern
/media/XXX/test/LR/2021-01-28_11-21-04_white.jpg
--mask_pattern
None
--weights_path
/media/XXX/deblur/DeblurGANv2-master/weights/fpn_inception.h5
--out_dir
submit/
--side_by_side
False
--video
False

三.跑通训练train.py

配置config中的config.yaml参数:

---
project: deblur_gan
experiment_desc: fpn #日志存储文件夹

train:
  files_a: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/train/LR/*.jpg #&FILES_A /datasets/my_dataset/**/*.jpg #low quality/blury images
  files_b: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/train/HR/*.jpg #*FILES_A #clean files
  size: &SIZE 256
  crop: random #裁剪方式选择,备选项为:center
  preload: &PRELOAD false
  preload_size: &PRELOAD_SIZE 0
  bounds: [0, .9]
  scope: geometric
  corrupt: &CORRUPT
    - name: cutout
      prob: 0.5 #数据增强概率
      num_holes: 3
      max_h_size: 25
      max_w_size: 25
    - name: jpeg #增强方式选择,配合aug.py 函数def _resolve_aug_fn(name)中查看挑选需要的增强方式
      quality_lower: 70
      quality_upper: 90
    - name: motion_blur #增强方式选择,配合aug.py 函数def _resolve_aug_fn(name)中查看挑选需要的增强方式
    - name: median_blur
    - name: gamma
    - name: rgb_shift
    - name: hsv_shift
    - name: sharpen

val:
  files_a: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/test/LR/*.jpg #*FILES_A
  files_b: /media/XXX/7292a4b1-2584-4296-8caf-eb9788c2ffb9/data/deblur/deblurGANv2/20211209/test/HR/*.jpg #*FILES_A
  size: *SIZE
  scope: geometric
  crop: center
  preload: *PRELOAD
  preload_size: *PRELOAD_SIZE
  bounds: [.9, 1]
  corrupt: *CORRUPT

phase: train
warmup_num: 3
model:
  g_name: fpn_inception
  blocks: 9
  d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale
  d_layers: 3
  content_loss: perceptual
  adv_lambda: 0.001
  disc_loss: wgan-gp
  learn_residual: True
  norm_layer: instance
  dropout: True

num_epochs: 200
train_batches_per_epoch: 1000 #训练进度条长度
val_batches_per_epoch: 100 #验证时进度条长度
batch_size: 1
image_size: [256, 256] #图像推理尺寸

optimizer:
  name: adam
  lr: 0.0001
scheduler:
  name: linear
  start_epoch: 50
  min_lr: 0.0000001

1.数据准备

注意:训练时数据推理尺寸为256*256,为了防止图像变形,所以使用的训练连样本都是宽高相等的图片;

准备自己的数据时,HR和LR图像的尺寸要相等,这个有别于超分辨率准备的数据,当HR和LR图像尺寸不相等时,模型训练的精度会一直起不来,本人训练时PSNR一直在16徘徊,跑了一晚上才醒悟(训练有问题啊);

2.数据增强方式

该项目中使用的是albumentations库,结合config中的config.yaml进行参数配置;

目前有这么多种增强方式,可修改源码

选其一:aug.py def get_transform中
albu.HorizontalFlip(always_apply=True), #左右翻转
albu.ShiftScaleRotate(always_apply=True),#随机仿射变换
albu.Transpose(always_apply=True),#转置
albu.OpticalDistortion(always_apply=True),#非刚体变换方法
albu.ElasticTransform(always_apply=True)#非刚体变换方法
albu.RandomCrop#随机裁剪
albu.CenterCrop#中心裁剪
选其一:aug.py def _resolve_aug_fn(name):
albu.Cutout,#随机擦除
albu.RGBShift,#对图像RGB的每个通道随机移动值
albu.HueSaturationValue,#随机更改图像的颜色,饱和度和值
albu.MotionBlur,#
albu.MedianBlur,
albu.RandomSnow,
albu.RandomShadow,
albu.RandomFog,#随机雾化
albu.RandomBrightnessContrast,
albu.RandomGamma,#随机灰度系数
albu.RandomSunFlare,
albu.Sharpen,
albu.ImageCompression,
albu.ToGray,
albu.Downscale,

3.加载预训练模型

train.py中的def _init_params(self):
self.criterionG, criterionD = get_loss(self.config['model'])
self.netG, netD = get_nets(self.config['model'])
self.netG.load_state_dict(torch.load("weights/fpn_inception.h5", map_location='cpu')['model'])

4.模型训练结果存储问题

按照原工程中的设置,日志文件存储在fpn文件夹下;

训练模型只存储最新一个和最好的一个模型,而且是存储在工程根目录下,没有另起一个文件夹存储,可以修改def train(self)中的代码:

原代码为:

if self.metric_counter.update_best_model():
    torch.save({'model': self.netG.state_dict()}, 
         'best_{}.h5'.format(self.config['experiment_desc']))
    torch.save({'model': self.netG.state_dict()
            }, 'last_{}.h5'.format(self.config['experiment_desc']))

修改为:

if self.metric_counter.update_best_model():
    torch.save({'model': self.netG.state_dict()},self.config['experiment_desc']+'/best_{}.h5'.format(self.config['experiment_desc']))
    torch.save({'model': self.netG.state_dict()}, self.config['experiment_desc']+'/last_{}.h5'.format(self.config['experiment_desc']))

if epoch // 50:
    torch.save({'model': self.netG.state_dict()}, self.config['experiment_desc']+'/epoch_{}.h5'.format(epoch))

其他链接:

图像去模糊之DeblurGAN-v2_年轻即出发,-CSDN博客_deblurganv2

  • 6
    点赞
  • 58
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 16
    评论
DeblurGAN-v2 是一种图像模糊的深度学习模型,可用于将模糊图像转换为清晰图像。在该模型中,使用了超像素技术来提高去模糊的效果。下面是利用超像素优化DeblurGAN-v2的PyTorch代码: 首先,需要安装以下依赖库: ``` pip install opencv-python pip install scikit-image pip install numpy pip install torch pip install torchvision pip install pydensecrf ``` 然后,加载DeblurGAN-v2模型和测试图像,并生成超像素: ```python import cv2 import torch import numpy as np from skimage.segmentation import slic from skimage.segmentation import mark_boundaries from skimage.color import rgb2gray from models.networks import define_G from options.test_options import TestOptions from util import util from pydensecrf.densecrf import DenseCRF2D # 加载模型 opt = TestOptions().parse() opt.nThreads = 1 opt.batchSize = 1 opt.serial_batches = True opt.no_flip = True model = define_G(opt) util.load_checkpoint(model, opt.pretrained) # 加载测试图像 img_path = 'path/to/test/image' img = cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) h, w, c = img.shape # 生成超像素 segments = slic(img, n_segments=100, sigma=5, compactness=10) ``` 接下来,将每个超像素作为输入,运行DeblurGAN-v2模型进行去模糊: ```python # 对每个超像素进行去模糊 result = np.zeros((h, w, c), dtype=np.float32) for i in np.unique(segments): mask = (segments == i).astype(np.uint8) masked_img = cv2.bitwise_and(img, img, mask=mask) if np.sum(mask) > 0: masked_img = masked_img[np.newaxis, :, :, :] masked_img = torch.from_numpy(masked_img.transpose((0, 3, 1, 2))).float() with torch.no_grad(): output = model(masked_img) output = output.cpu().numpy() output = output.transpose((0, 2, 3, 1)) output = np.squeeze(output) result += output * mask[:, :, np.newaxis] # 对结果进行后处理 result /= 255.0 result = np.clip(result, 0, 1) result = (result * 255).astype(np.uint8) ``` 最后,使用密集条件随机场(DenseCRF)算法对结果进行后处理,以进一步提高去模糊的效果: ```python # 使用DenseCRF算法进行后处理 d = DenseCRF2D(w, h, 2) result_softmax = np.stack([result, 255 - result], axis=0) result_softmax = result_softmax.astype(np.float32) / 255.0 unary = -np.log(result_softmax) unary = unary.reshape((2, -1)) d.setUnaryEnergy(unary) d.addPairwiseGaussian(sxy=5, compat=3) d.addPairwiseBilateral(sxy=20, srgb=3, rgbim=img, compat=10) q = d.inference(5) q = np.argmax(np.array(q), axis=0).reshape((h, w)) result = q * 255 ``` 完整代码如下: ```python import cv2 import torch import numpy as np from skimage.segmentation import slic from skimage.segmentation import mark_boundaries from skimage.color import rgb2gray from models.networks import define_G from options.test_options import TestOptions from util import util from pydensecrf.densecrf import DenseCRF2D # 加载模型 opt = TestOptions().parse() opt.nThreads = 1 opt.batchSize = 1 opt.serial_batches = True opt.no_flip = True model = define_G(opt) util.load_checkpoint(model, opt.pretrained) # 加载测试图像 img_path = 'path/to/test/image' img = cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) h, w, c = img.shape # 生成超像素 segments = slic(img, n_segments=100, sigma=5, compactness=10) # 对每个超像素进行去模糊 result = np.zeros((h, w, c), dtype=np.float32) for i in np.unique(segments): mask = (segments == i).astype(np.uint8) masked_img = cv2.bitwise_and(img, img, mask=mask) if np.sum(mask) > 0: masked_img = masked_img[np.newaxis, :, :, :] masked_img = torch.from_numpy(masked_img.transpose((0, 3, 1, 2))).float() with torch.no_grad(): output = model(masked_img) output = output.cpu().numpy() output = output.transpose((0, 2, 3, 1)) output = np.squeeze(output) result += output * mask[:, :, np.newaxis] # 对结果进行后处理 result /= 255.0 result = np.clip(result, 0, 1) result = (result * 255).astype(np.uint8) # 使用DenseCRF算法进行后处理 d = DenseCRF2D(w, h, 2) result_softmax = np.stack([result, 255 - result], axis=0) result_softmax = result_softmax.astype(np.float32) / 255.0 unary = -np.log(result_softmax) unary = unary.reshape((2, -1)) d.setUnaryEnergy(unary) d.addPairwiseGaussian(sxy=5, compat=3) d.addPairwiseBilateral(sxy=20, srgb=3, rgbim=img, compat=10) q = d.inference(5) q = np.argmax(np.array(q), axis=0).reshape((h, w)) result = q * 255 # 显示结果 result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) cv2.imshow('result', result) cv2.waitKey(0) cv2.destroyAllWindows() ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

猫猫与橙子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值