MPRNet 训练自己的数据集

code:https://github.com/swz30/MPRNet
paper:https://openaccess.thecvf.com/content/CVPR2021/papers/Zamir_Multi-Stage_Progressive_Image_Restoration_CVPR_2021_paper.pdf
1.不是吧,还有人想着一个环境跑所有代码?废话少说。搭建环境

conda create -n pytorch1 python=3.7
conda activate pytorch1
conda install pytorch=1.1 torchvision=0.3 cudatoolkit=9.0 
pip install matplotlib scikit-image opencv-python yacs joblib natsort h5py tqdm
#低版本的pytorch安装不上,其他版本的pytorch也可以。

2.download code,进入项目目录下进行编译

cd pytorch-gradual-warmup-lr; python setup.py install; cd ..

3.如果你想复现author的结果,直接下载预训练模型跟数据集。执行
deblurring预训练模型

python demo.py --task Task_Name --input_dir path_to_images --result_dir save_images_here
#eg
python demo.py --task Deblurring --input_dir ./samples/input/ --result_dir ./samples/output/
#以去模糊为例,可下载对应的预训练模型跟数据集测试其他任务:去雨或去噪

放张图对比一下:模糊图与去完模糊后的图

在这里插入图片描述
在这里插入图片描述
4.训练自己的数据集,首先去到deblurring项目下,数据格式按GOPro数据集准备。

###############
## 
####
#一张卡就改成你自己的卡
GPU: [0,1,2,3,4,5,6,7]

VERBOSE: True

MODEL:
  MODE: 'Deblurring'
  SESSION: 'MPRNet'

# Optimization arguments.\
#按自己的数据集修改batch_size、num_epochs
OPTIM:
  BATCH_SIZE: 16
  NUM_EPOCHS: 3000
  # NEPOCH_DECAY: [10]
  LR_INITIAL: 2e-4
  LR_MIN: 1e-6
  # BETA1: 0.9

TRAINING:
  VAL_AFTER_EVERY: 20
  RESUME: False
  TRAIN_PS: 256
  VAL_PS: 256
  #xxx是你自己的数据集
  TRAIN_DIR: './Datasets/xxx/train' # path to training data
  VAL_DIR: './Datasets/xxx/test'    # path to validation data
  SAVE_DIR: './checkpoints'     # path to save models and images
  # SAVE_IMAGES: False

5.执行train.py,开始训练了是不是很简单

在这里插入图片描述
6.没一会儿就训练完了卧槽,测试一下。对test.py添加参数

arser = argparse.ArgumentParser(description='Image Deblurring using MPRNet')

parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images')
#测试结果保存路径
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
#训练完保存的模型。训练时配置文件就已经指定了模型保存路径,保存了很多模型选best的
parser.add_argument('--weights', 
default='./checkpoints/model_deblurring.pth', type=str, help='Path to weights')
#xxx是你自己的数据集噢
parser.add_argument('--dataset', default='XXX', type=str, help='Test Dataset') # ['GoPro', 'HIDE', 'RealBlur_J', 'RealBlur_R']
#没有多卡就默认gpu id为0
parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')

args = parser.parse_args()

7.自己的数据集就不展示了,涉密。PSNR、SSIM用的是matlab代码,当然也有python的可自行选择。

import os
import numpy as np
from glob import glob
from natsort import natsorted
from skimage import io
import cv2
from skimage.metrics import structural_similarity
from tqdm import tqdm
import concurrent.futures

def image_align(deblurred, gt):
  # this function is based on kohler evaluation code
  z = deblurred
  c = np.ones_like(z)
  x = gt

  zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching

  warp_mode = cv2.MOTION_HOMOGRAPHY
  warp_matrix = np.eye(3, 3, dtype=np.float32)

  # Specify the number of iterations.
  number_of_iterations = 100

  termination_eps = 0

  criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
              number_of_iterations, termination_eps)

  # Run the ECC algorithm. The results are stored in warp_matrix.
  (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)

  target_shape = x.shape
  shift = warp_matrix

  zr = cv2.warpPerspective(
    zs,
    warp_matrix,
    (target_shape[1], target_shape[0]),
    flags=cv2.INTER_CUBIC+ cv2.WARP_INVERSE_MAP,
    borderMode=cv2.BORDER_REFLECT)

  cr = cv2.warpPerspective(
    np.ones_like(zs, dtype='float32'),
    warp_matrix,
    (target_shape[1], target_shape[0]),
    flags=cv2.INTER_NEAREST+ cv2.WARP_INVERSE_MAP,
    borderMode=cv2.BORDER_CONSTANT,
    borderValue=0)

  zr = zr * cr
  xr = x * cr

  return zr, xr, cr, shift

def compute_psnr(image_true, image_test, image_mask, data_range=None):
  # this function is based on skimage.metrics.peak_signal_noise_ratio
  err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
  return 10 * np.log10((data_range ** 2) / err)


def compute_ssim(tar_img, prd_img, cr1):
    ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, use_sample_covariance=False, data_range = 1.0, full=True)
    ssim_map = ssim_map * cr1
    r = int(3.5 * 1.5 + 0.5)  # radius as in ndimage
    win_size = 2 * r + 1
    pad = (win_size - 1) // 2
    ssim = ssim_map[pad:-pad,pad:-pad,:]
    crop_cr1 = cr1[pad:-pad,pad:-pad,:]
    ssim = ssim.sum(axis=0).sum(axis=0)/crop_cr1.sum(axis=0).sum(axis=0)
    ssim = np.mean(ssim)
    return ssim

def proc(filename):
    tar,prd = filename
    tar_img = io.imread(tar)
    prd_img = io.imread(prd)
    
    tar_img = tar_img.astype(np.float32)/255.0
    prd_img = prd_img.astype(np.float32)/255.0
    
    prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)

    PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
    SSIM = compute_ssim(tar_img, prd_img, cr1)
    return (PSNR,SSIM)

datasets = ['RealBlur_J', 'RealBlur_R']

for dataset in datasets:

    file_path = os.path.join('results' , dataset)
    gt_path = os.path.join('Datasets', dataset, 'test', 'target')

    path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
    gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))

    assert len(path_list) != 0, "Predicted files not found"
    assert len(gt_list) != 0, "Target files not found"

    psnr, ssim = [], []
    img_files =[(i, j) for i,j in zip(gt_list,path_list)]
    with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
        for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
            psnr.append(PSNR_SSIM[0])
            ssim.append(PSNR_SSIM[1])

    avg_psnr = sum(psnr)/len(psnr)
    avg_ssim = sum(ssim)/len(ssim)

    print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))

https://github.com/swz30/MPRNet

评论 67
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

VisionX Lab

你的鼓励将是我更新的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值