code:https://github.com/swz30/MPRNet
paper:https://openaccess.thecvf.com/content/CVPR2021/papers/Zamir_Multi-Stage_Progressive_Image_Restoration_CVPR_2021_paper.pdf
1.搭建环境(建议为本项目单独创建一个conda环境,不要用一个环境跑所有代码)
conda create -n pytorch1 python=3.7
conda activate pytorch1
conda install pytorch=1.1 torchvision=0.3 cudatoolkit=9.0 -c pytorch
pip install matplotlib scikit-image opencv-python yacs joblib natsort h5py tqdm
#低版本的pytorch安装不上,其他版本的pytorch也可以。
2.download code,进入项目目录下进行编译
cd pytorch-gradual-warmup-lr; python setup.py install; cd ..
3.如果你想复现author的结果,直接下载预训练模型跟数据集。执行
deblurring预训练模型
python demo.py --task Task_Name --input_dir path_to_images --result_dir save_images_here
#eg
python demo.py --task Deblurring --input_dir ./samples/input/ --result_dir ./samples/output/
#以去模糊为例,可下载对应的预训练模型跟数据集测试其他任务:去雨或去噪
放张图对比一下:模糊图与去完模糊后的图
4.训练自己的数据集,首先去到deblurring项目下,数据格式按GOPro数据集准备。
###############
##
####
#一张卡就改成你自己的卡
GPU: [0,1,2,3,4,5,6,7]
VERBOSE: True
MODEL:
MODE: 'Deblurring'
SESSION: 'MPRNet'
# Optimization arguments.\
#按自己的数据集修改batch_size、num_epochs
OPTIM:
BATCH_SIZE: 16
NUM_EPOCHS: 3000
# NEPOCH_DECAY: [10]
LR_INITIAL: 2e-4
LR_MIN: 1e-6
# BETA1: 0.9
TRAINING:
VAL_AFTER_EVERY: 20
RESUME: False
TRAIN_PS: 256
VAL_PS: 256
#xxx是你自己的数据集
TRAIN_DIR: './Datasets/xxx/train' # path to training data
VAL_DIR: './Datasets/xxx/test' # path to validation data
SAVE_DIR: './checkpoints' # path to save models and images
# SAVE_IMAGES: False
5.执行train.py,开始训练了是不是很简单
6.训练完成后进行测试。对test.py添加参数
arser = argparse.ArgumentParser(description='Image Deblurring using MPRNet')
parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images')
#测试结果保存路径
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
#训练完保存的模型。训练时配置文件就已经指定了模型保存路径,保存了很多模型选best的
parser.add_argument('--weights',
default='./checkpoints/model_deblurring.pth', type=str, help='Path to weights')
#xxx是你自己的数据集噢
parser.add_argument('--dataset', default='XXX', type=str, help='Test Dataset') # ['GoPro', 'HIDE', 'RealBlur_J', 'RealBlur_R']
#没有多卡就默认gpu id为0
parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
args = parser.parse_args()
7.自己的数据集就不展示了,涉密。PSNR、SSIM用的是matlab代码,当然也有python的可自行选择。
import os
import numpy as np
from glob import glob
from natsort import natsorted
from skimage import io
import cv2
from skimage.metrics import structural_similarity
from tqdm import tqdm
import concurrent.futures
def image_align(deblurred, gt):
# this function is based on kohler evaluation code
z = deblurred
c = np.ones_like(z)
x = gt
zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
warp_mode = cv2.MOTION_HOMOGRAPHY
warp_matrix = np.eye(3, 3, dtype=np.float32)
# Specify the number of iterations.
number_of_iterations = 100
termination_eps = 0
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
number_of_iterations, termination_eps)
# Run the ECC algorithm. The results are stored in warp_matrix.
(cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
target_shape = x.shape
shift = warp_matrix
zr = cv2.warpPerspective(
zs,
warp_matrix,
(target_shape[1], target_shape[0]),
flags=cv2.INTER_CUBIC+ cv2.WARP_INVERSE_MAP,
borderMode=cv2.BORDER_REFLECT)
cr = cv2.warpPerspective(
np.ones_like(zs, dtype='float32'),
warp_matrix,
(target_shape[1], target_shape[0]),
flags=cv2.INTER_NEAREST+ cv2.WARP_INVERSE_MAP,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0)
zr = zr * cr
xr = x * cr
return zr, xr, cr, shift
def compute_psnr(image_true, image_test, image_mask, data_range=None):
# this function is based on skimage.metrics.peak_signal_noise_ratio
err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
return 10 * np.log10((data_range ** 2) / err)
def compute_ssim(tar_img, prd_img, cr1):
ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, use_sample_covariance=False, data_range = 1.0, full=True)
ssim_map = ssim_map * cr1
r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
win_size = 2 * r + 1
pad = (win_size - 1) // 2
ssim = ssim_map[pad:-pad,pad:-pad,:]
crop_cr1 = cr1[pad:-pad,pad:-pad,:]
ssim = ssim.sum(axis=0).sum(axis=0)/crop_cr1.sum(axis=0).sum(axis=0)
ssim = np.mean(ssim)
return ssim
def proc(filename):
tar,prd = filename
tar_img = io.imread(tar)
prd_img = io.imread(prd)
tar_img = tar_img.astype(np.float32)/255.0
prd_img = prd_img.astype(np.float32)/255.0
prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
SSIM = compute_ssim(tar_img, prd_img, cr1)
return (PSNR,SSIM)
datasets = ['RealBlur_J', 'RealBlur_R']
for dataset in datasets:
file_path = os.path.join('results' , dataset)
gt_path = os.path.join('Datasets', dataset, 'test', 'target')
path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))
assert len(path_list) != 0, "Predicted files not found"
assert len(gt_list) != 0, "Target files not found"
psnr, ssim = [], []
img_files =[(i, j) for i,j in zip(gt_list,path_list)]
with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
psnr.append(PSNR_SSIM[0])
ssim.append(PSNR_SSIM[1])
avg_psnr = sum(psnr)/len(psnr)
avg_ssim = sum(ssim)/len(ssim)
print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
https://github.com/swz30/MPRNet