【图像去噪】论文复现:新手入门必看!DnCNN的Pytorch源码训练测试全流程解析!为源码做详细注释!补充DnCNN-B和DnCNN-3的模型训练和测试!附各种情况下训练好的模型权重文件!

请先看【专栏介绍文章】:【图像去噪(Image Denoising)】关于【图像去噪】专栏的相关说明,包含适配人群、专栏简介、专栏亮点、阅读方法、定价理由、品质承诺、关于更新、去噪概述、文章目录、资料汇总、问题汇总(更新中)

源码只提供了noise level为25的DnCNN-S模型文件。本文末尾有完整代码和训练好的σ=15,25,50的DnCNN-S、σ ∈ [0, 55]的DnCNN-B和CDnCNN-B、DnCNN-3共6个模型文件!读者可以自行下载!

本文亮点:

  • 以官方Pytorch源代码为基础,在DnCNN-S的基础上,增添DnCNN-B/CDnCNN-B、DnCNN-3模型训练和测试复现,代码注释非常详细,无论是科研还是应用,新手小白都能看懂,学习阅读毫无压力,去噪入门必备,适用于去噪、超分、JPEG去块任务
  • 提供新增后的完整代码和训练好的模型权重文件,模型性能与论文中近似,可不训练直接测试;
  • 理论和源码结合,进一步加深理解算法原理、明确训练和测试流程
  • 更换路径和相关参数即可训练自己的图像数据集
  • 几乎实现论文中全部的图表,相当于整个工作自己做了一遍,非常全面


前言

论文题目:Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising —— 除了高斯去噪器:深度CNN残差学习图像去噪

论文地址:Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising

论文源码:https://github.com/cszn/DnCNN

对应的论文精读:【图像去噪】论文精读:Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising(DnCNN)

请先看DnCNN的论文精读,先大概了解算法原理,再读本文的代码,事半功倍!

DnCNN的Pytorch版本代码在源代码项目中的TrainingCodes/dncnn_pytorch/路径下。下载好代码跟着我一起复现学习吧!建议直接下载增添内容后的完整代码!

准备工作:BSD400数据集作为训练集,BSD68和Set12作为测试集。请读者自行搜索并下载好数据集,尽量在数据集所在论文提供的官方途径下载,以免图像质量不同造成评价指标误差。

本文使用代码:
https://github.com/cszn/DnCNN/tree/master/TrainingCodes/dncnn_pytorch

训练集:https://github.com/cszn/DnCNN/tree/master/TrainingCodes/DnCNN_TrainingCodes_v1.0/data/Train400

源码项目文件说明

在这里插入图片描述
下载好源码后,将BSD400数据集放到data目录下,并改名为Train400;在data目录下新建Test文件夹,将BSD68和Set12放入其中,并将BSD68改为Set68。(这样修改的目的是与源码对应,修改完直接运行代码即可。当然,也可以改源码中对应位置路径,详细流程只墨迹这一次,以后的复现文章不会如此啰嗦。

文件相关说明

  • data文件夹存放训练集和测试集
  • models文件夹存放训练好的模型
  • results文件夹存放去噪结果(可选是否保存)
  • data_generator.py:制作数据集(切块,转成Tensor)
  • main_test.py:在测试集上测试模型,输出去噪后图像,计算测试集上的平均PSNR和SSIM
  • main_train.py:训练DnCNN

使用方式

  1. 放置好数据集
  2. 运行main_train.py训练
  3. 运行main_test.py测试

训练和测试不同模型请修改对应的参数。无论是windows下还是linux下,建议修改parser的默认值为你所需要的值后再去跑,避免命令输错。

数据预处理

本节对应data_generator.py。

提炼整理论文IV-A与训练集和测试集有关信息

  • 训练集:BSD400,每张图像是180×180的灰度图(灰度图去噪);BSD432(彩色图去噪)
  • mini-batch:128
  • patch_size:40×40,128×1600(DnCNN-S);50×50,128×3000(DnCNN-B和CDnCNN-B);50×50,128×8000(DnCNN-3)
  • 数据增强:旋转/翻转,只在训练DnCNN-3时使用(可能训练其他模型也用了数据增强,但是在DnCNN-3那一段阐述的

源码中的实现以及与论文所述的区别

  • 数据增强:源码还用了缩放
  • 对于DnCNN-S,patch_size=40,stride=10,裁剪后总块数为BSD400图像块数量为238336,与128 * 1600最为接近。步长只能取整数,除非人工减少块数,否则在固定数据增强手段后,无法与论文中的裁剪块数完全一致。
  • 根据上述原则,DnCNN-B的patch_size, stride = 50, 7,BSD400图像块数量为386688,与128 * 3000接近;DnCNN-3的patch_size, stride = 50, 4,BSD400图像块数量为1146752,与128 * 8000接近;

对BSD400数据集图像切块以及转成Tensor的实现思路:

在对应的缩放倍数图像下,以块大小和移动步长扫描整个图像,直至每张图像切块完毕,应用随机数据增强,并将所有的图像块放到data列表中。

图像是cv2读取,为numpy类型,根据图像格式,将data制作成(n,h,w,c)的形式,生成与data中元素相同形状的噪声,得到加噪前后的图像Tensor,以供Pytorch的DataLoader使用。

data_generator.py加注释后的代码

# -*- coding: utf-8 -*-

# =============================================================================
#  @article{zhang2017beyond,
#    title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#    journal={IEEE Transactions on Image Processing},
#    year={2017},
#    volume={26}, 
#    number={7}, 
#    pages={3142-3155}, 
#  }
# by Kai Zhang (08/2018)
# cskaizhang@gmail.com
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# no need to run this code separately


import glob
import cv2
import numpy as np
# from multiprocessing import Pool
from torch.utils.data import Dataset
import torch

# 固定数据增强的情况下,块大小为40,步长为10,BSD400图像块数量为238336,与128 * 1600接近
# 固定数据增强的情况下,块大小为50,步长为7,BSD400图像块数量为386688,与128 * 3000接近
# 固定数据增强的情况下,块大小为50,步长为4,BSD400图像块数量为1146752,与128 * 8000接近
patch_size, stride = 40, 10     # 图像块大小,步长
aug_times = 1   # 每个图像块增强次数
scales = [1, 0.9, 0.8, 0.7] # 数据增强缩放
batch_size = 128    # mini-batch大小

# 封装带噪声和不带噪声的图像块
class DenoisingDataset(Dataset):
    """Dataset wrapping tensors.
    Arguments:
        xs (Tensor): clean image patches, (n,c,h,w)四维张量,在训练前制作好,n是总图像块数量
        sigma: noise level, e.g., 25
    """
    def __init__(self, xs, sigma):
        super(DenoisingDataset, self).__init__()
        self.xs = xs
        self.sigma = sigma

    def __getitem__(self, index):
        batch_x = self.xs[index]  # 每个图像块
        # 噪声生成:生成与batch_x相同形状满足标准正太分布的张量,然后按元素乘[0,255]像素范围内的噪声标准差
        noise = torch.randn(batch_x.size()).mul_(self.sigma/255.0)
        batch_y = batch_x + noise
        return batch_y, batch_x

    def __len__(self):
        return self.xs.size(0)


# 展示图像块
def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()

# 数据增强选项
def data_aug(img, mode=0):
    # data augmentation
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(img)
    elif mode == 2:
        return np.rot90(img)
    elif mode == 3:
        return np.flipud(np.rot90(img))
    elif mode == 4:
        return np.rot90(img, k=2)
    elif mode == 5:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 6:
        return np.rot90(img, k=3)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))

# 生成图像块
def gen_patches(file_name):
    # get multiscale patches from a single image
    img = cv2.imread(file_name, 0)  # gray scale
    h, w = img.shape
    patches = []
    for s in scales: # 每个缩放倍数下
        h_scaled, w_scaled = int(h*s), int(w*s)
        img_scaled = cv2.resize(img, (h_scaled, w_scaled), interpolation=cv2.INTER_CUBIC)
        # extract patches, 缩放后按块大小和步长裁剪图像块,并应用随机数据增强
        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                x = img_scaled[i:i+patch_size, j:j+patch_size]
                for k in range(0, aug_times):
                    x_aug = data_aug(x, mode=np.random.randint(0, 8))
                    patches.append(x_aug)
    return patches

# 得到训练集中所有的图像块
def datagenerator(data_dir='data/Train400', verbose=False):
    # generate clean patches from a dataset
    file_list = glob.glob(data_dir+'/*.png')  # get name list of all .png files
    # initrialize
    data = []
    # generate patches
    for i in range(len(file_list)):
        patches = gen_patches(file_list[i])
        for patch in patches:    
            data.append(patch) # data列表中每个元素是一个图像块
        if verbose:
            print(str(i+1) + '/' + str(len(file_list)) + ' is done ^_^')
    data = np.array(data, dtype='uint8') # 转成numpy,(n,h,w)
    data = np.expand_dims(data, axis=3) # (n,h,w,1),因为网络输入输出通道都是1,直接添加一维为通道数,满足DataLoader的tensor格式
    discard_n = len(data)-len(data)//batch_size*batch_size  # 多余的样本数量
    data = np.delete(data, range(discard_n), axis=0)    # 删除多余的,保证整除batch_size
    print('^_^-training data finished-^_^')
    return data # (n,h,w,1)


if __name__ == '__main__': 

    data = datagenerator(data_dir='data/Train400')
    print(len(data))
    print(data.shape)
    # show(data[100])

#    print('Shape of result = ' + str(res.shape))
#    print('Saving data...')
#    if not os.path.exists(save_dir):
#            os.mkdir(save_dir)
#    np.save(save_dir+'clean_patches.npy', res)
#    print('Done.')       

各个函数的作用以及相关细节请见代码注释。

DnCNN模型训练

本节对应main_train.py。主要包含三个部分:DnCNN模型结构定义、损失函数定义、训练参数与训练过程

DnCNN模型结构

在这里插入图片描述
论文图1为DnCNN的网络结构

  • 输入:噪声图像
  • 第一层:Conv+ReLU
  • 中间若干层:Conv+BN+ReLU
  • 最后一层:Conv
  • 输出:噪声

论文中模型相关细节

  • 所有卷积层卷积核为3×3
  • 特征数为64
  • 使用零填充减少边缘伪影(在源码中padding=1
  • DnCNN是学习噪声,但是我们的目的还是要去噪后的干净图像,所以在网络前向传播时,输出为带噪声图像减去噪声。
  • 网络深度:指输入和输出之间的层数,比如定义网络深度为17,则中间的Conv+BN+ReLU为15个。
  • 使用Kaiming初始化权重(论文中参考文献[34],实验参数设置部分提到

源码实现与论文描述基本一致,有一些细节需要注意:

  • BN层手动添加了eps和momentum
  • 后跟BN层的Conv没有使用bias,减少计算量
  • 权重初始化使用nn.init.orthogonal_的正交初始化,而没有使用nn.init.kaiming_normal_凯明初始化(对模型性能影响应该不大

DnCNN模型结构加注释后的代码

class DnCNN(nn.Module):
    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []

        # 第一层使用了偏置,中间层没使用,默认bias为True
        # 一般是跟BN的Conv的bias为False,因为对BN的计算没用,可以减少计算量。如果不考虑计算量,加不加都行
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            # eps避免计算标准差分母为0,一般为1e-5(默认)或1e-4
            # momentum更新率,一般为0.9或0.95;默认为0.1,即只使用新批次的10%数据,之前的90%来自移动平均值
            # 切块的数据形式,BN的动量接近1会比较好
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x) 
        return y-out    

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

损失函数

在这里插入图片描述
损失函数是模型学习的指导,明确任务目标。根据论文公式(1),损失函数定义为残差图像与带噪声图像直接的平均均方误差,注意有个二分之一。换句话说,DnCNN是要学习噪声,而不是去噪后的图像。由于上述DnCNN模型定义的输出是预测的去噪后的图像,所以损失为input-target,即为噪声。学习最优噪声对应的模型参数,即为最优模型。

损失函数源码如下:

class sum_squared_error(_Loss):  # PyTorch 0.4.1
    """
    Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
    The backward is defined as: input-target
    """
    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super(sum_squared_error, self).__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
        return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)

训练参数与训练过程

论文中描述

  • 优化器SGD,weight decay为0.0001,momentum为0.9
  • mini-batch为128
  • epoch:50
  • lr:1e-1降到1e-4,每10个epoch降低10倍

Pytorch源码中的参数:

  • 优化器为Adam
  • epoch:180
  • lr:1e-3,每30个epoch降低20%

注:训练参数是Pytorch版本代码与论文描述差别最大的地方。究其原因,应该是Pytorch与matlab的不同导致,Pytorch中SGD的1e-1学习率太大,很容易梯度爆炸或者梯度消失。所以,在Pytorch版本中原作者将优化器改为Adam,学习率改为1e-3,可以较为平滑顺利的训练。我们先按论文作者给出的源码复现,后续再改参数实验。

题外话:不必太过纠结与论文中参数描述一致,重在对于源码的学习,熟悉训练和测试流程。

训练过程包含数据读取、模型定义、相关调用、模型保存、loss计算等,不再详述,请读者自行阅读源码,注释已经很详细了。

main_train.py加注释后的代码:

# -*- coding: utf-8 -*-

# PyTorch 0.4.1, https://pytorch.org/docs/stable/index.html

# =============================================================================
#  @article{zhang2017beyond,
#    title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#    journal={IEEE Transactions on Image Processing},
#    year={2017},
#    volume={26}, 
#    number={7}, 
#    pages={3142-3155}, 
#  }
# by Kai Zhang (08/2018)
# cskaizhang@gmail.com
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# run this to train the model

# =============================================================================
# For batch normalization layer, momentum should be a value from [0.1, 1] rather than the default 0.1. 
# The Gaussian noise output helps to stablize the batch normalization, thus a large momentum (e.g., 0.95) is preferred.
# =============================================================================

import argparse
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import data_generator as dg
from data_generator import DenoisingDataset


# Params
parser = argparse.ArgumentParser(description='PyTorch DnCNN')
parser.add_argument('--model', default='DnCNN', type=str, help='choose a type of model')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--train_data', default='data/Train400', type=str, help='path of train data')
parser.add_argument('--sigma', default=15, type=int, help='noise level')
parser.add_argument('--epoch', default=180, type=int, help='number of train epoches')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
args = parser.parse_args()

batch_size = args.batch_size
cuda = torch.cuda.is_available()
n_epoch = args.epoch
sigma = args.sigma

save_dir = os.path.join('models', args.model + '_' + 'sigma' + str(sigma)) # 权重文件保存路径

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

# 模型结构
class DnCNN(nn.Module):
    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []

        # 第一层使用了偏置,中间层没使用,默认bias为True
        # 一般是跟BN的Conv的bias为False,因为对BN的计算没用,可以减少计算量。如果不考虑计算量,加不加都行
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            # eps避免计算标准差分母为0,一般为1e-5(默认)或1e-4
            # momentum更新率,一般为0.9或0.95;默认为0.1,即只使用新批次的10%数据,之前的90%来自移动平均值
            # 切块的数据形式,BN的动量接近1会比较好
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x) # 带噪声图像经过网络得到噪声
        return y-out    # 模型输出是预测的干净图像

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

# 损失函数定义
class sum_squared_error(_Loss):  # PyTorch 0.4.1
    """
    Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
    The backward is defined as: input-target
    """
    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super(sum_squared_error, self).__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
        return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)

# 找到最后一个保存的模型的epoch
def findLastCheckpoint(save_dir):
    file_list = glob.glob(os.path.join(save_dir, 'model_*.pth'))
    if file_list:
        epochs_exist = []
        for file_ in file_list:
            result = re.findall(".*model_(.*).pth.*", file_)
            epochs_exist.append(int(result[0]))
        initial_epoch = max(epochs_exist)
    else:
        initial_epoch = 0
    return initial_epoch


def log(*args, **kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = DnCNN()
    
    initial_epoch = findLastCheckpoint(save_dir=save_dir)  # load the last model in matconvnet style
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        # model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)))
        model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)) # 如果有训练过的模型,载入最近的一个
    model.train()
    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1
    criterion = sum_squared_error()
    if cuda:
        model = model.cuda()
         # device_ids = [0]
         # model = nn.DataParallel(model, device_ids=device_ids).cuda()
         # criterion = criterion.cuda()

    # 这里与论文不同,论文是SGD,1e-1降到1e-4,50个epoch
    # 本例是180个epoch,初始学习率为1e-3,每30个epoch降低20%
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)  # learning rates
    for epoch in range(initial_epoch, n_epoch):

        scheduler.step(epoch)  # step to the learning rate in this epcoh
        xs = dg.datagenerator(data_dir=args.train_data) # 读取数据,xs为(n,h,w,1)
        xs = xs.astype('float32')/255.0 # 归一化
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))  # nhwc ——> nchw
        DDataset = DenoisingDataset(xs, sigma) # 带噪和不带噪的所有图像块tensor
        DLoader = DataLoader(dataset=DDataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True) # 训练集
        #数据集制作是将n个块分给128批次,DLoader的长度为 n / 128
        epoch_loss = 0
        start_time = time.time()

        for n_count, batch_yx in enumerate(DLoader): # 读取训练集
                optimizer.zero_grad()
                if cuda:
                    batch_x, batch_y = batch_yx[1].cuda(), batch_yx[0].cuda()
                # batch_x是干净图像,batch_y是带噪声图像
                # 损失是模型预测的去噪图像与干净图像之间的MSE,训练的目的就是让这个损失最小,模型参数最优
                loss = criterion(model(batch_y), batch_x)
                epoch_loss += loss.item()
                loss.backward()
                optimizer.step()
                if n_count % 10 == 0: # 每10个epoch输出一下,每个mini-batch下的平均损失
                    print('%4d %4d / %4d loss = %2.4f' % (epoch+1, n_count, xs.size(0)//batch_size, loss.item()/batch_size))
        elapsed_time = time.time() - start_time

        # 保存日志
        log('epcoh = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, epoch_loss/n_count, elapsed_time))
        np.savetxt('train_result.txt', np.hstack((epoch+1, epoch_loss/n_count, elapsed_time)), fmt='%2.4f')
        # torch.save(model.state_dict(), os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))

        # 保存每个epoch下的模型
        torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))

在测试集上测试DnCNN模型

本节对应main_test.py。主要包含模型推理,生成去噪后的图像,以及评估指标PSNR和SSIM的计算。

实现思路:遍历测试集图像,加噪输入模型,输出就是去噪后的图像,变换图像格式保存,原图和去噪后的图像之间计算指标。注意图像间的格式转换(numpy,skimage,pil,tensor等)。

对于高版本的Pytorch,源代码有以下两个报错

  1. 计算psnr和ssim报错:新版本的skimage的compare_psnr和compare_ssim与旧版的不同,已按新版修改,具体请看代码及注释。
  2. 参数save_result设置为1时,图像保存报错:原因为skimage.io的imsave保存图像时不支持np.float32,会有报错OSError: cannot write mode F as PNG,已修改。

修改报错后的main_test.py代码如下:

# -*- coding: utf-8 -*-

# =============================================================================
#  @article{zhang2017beyond,
#    title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#    journal={IEEE Transactions on Image Processing},
#    year={2017},
#    volume={26}, 
#    number={7}, 
#    pages={3142-3155}, 
#  }
# by Kai Zhang (08/2018)
# cskaizhang@gmail.com
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# run this to test the model

import argparse
import os, time, datetime
# 图像读取方式不同,可能指标不同
# import PIL.Image as Image
import PIL.Image as pil_image
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch

# 旧版计算指标已失效
# from skimage.measure import compare_psnr, compare_ssim

from skimage.io import imread, imsave
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    parser.add_argument('--set_names', default=['Set68', 'Set12'], help='directory of test dataset')
    parser.add_argument('--sigma', default=50, type=int, help='noise level')
    parser.add_argument('--model_dir', default=os.path.join('models', 'DnCNN_sigma50'), help='directory of the model')
    parser.add_argument('--model_name', default='model_180.pth', type=str, help='the model name')
    parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset')
    parser.add_argument('--save_result', default=1, type=int, help='save the denoised image, 1 or 0')
    return parser.parse_args()


def log(*args, **kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)

# 保存结果
def save_result(result, path):
    path = path if path.find('.') != -1 else path+'.png'
    ext = os.path.splitext(path)[-1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        # imsave(path, np.clip(result, 0, 1))
        imsave(path, result)


def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()

# 测试时修改对应depth
class DnCNN(nn.Module):

    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


if __name__ == '__main__':

    args = parse_args()

    # model = DnCNN()
    if not os.path.exists(os.path.join(args.model_dir, args.model_name)):

        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        # load weights into new model
        log('load trained model on Train400 dataset by kai')
    else:
        # model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        log('load trained model')

#    params = model.state_dict()
#    print(params.values())
#    print(params.keys())
#
#    for key, value in params.items():
#        print(key)    # parameter name
#    print(params['dncnn.12.running_mean'])
#    print(model.state_dict())

    model.eval()  # evaluation mode
#    model.train()

    if torch.cuda.is_available():
        model = model.cuda()

    if not os.path.exists(args.result_dir):
        os.mkdir(args.result_dir)

    for set_cur in args.set_names: # 对于每一个测试集

        # 生成与测试集同名的结果文件夹
        if not os.path.exists(os.path.join(args.result_dir, set_cur)):
            os.mkdir(os.path.join(args.result_dir, set_cur))

        # 保存每个训练集的平均指标
        psnrs = []
        ssims = []

        # 获取到测试集中的图像
        for im in os.listdir(os.path.join(args.set_dir, set_cur)):
            # 如果后缀是图像格式
            if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):

                # 读取图像并转成float32的numpy并归一化
                x = np.array(imread(os.path.join(args.set_dir, set_cur, im)), dtype=np.float32)/255.0
                # 设置随机种子保证可重复性
                np.random.seed(seed=0)  # for reproducibility
                # 给图像添加对应level的噪声
                y = x + np.random.normal(0, args.sigma/255.0, x.shape)  # Add Gaussian noise without clipping
                # ,转成float32,与pytorch要求一致
                y = y.astype(np.float32)
                # 转成模型输入的tensor(b,c,h,w),b为1
                y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1])

                torch.cuda.synchronize() # 阻塞程序确保GPU操作完成,目的是保证测量时间的准确性
                start_time = time.time()
                y_ = y_.cuda()
                x_ = model(y_)  # inference
                x_ = x_.view(y.shape[0], y.shape[1]) # 转成(h,w)
                x_ = x_.cpu()
                x_ = x_.detach().numpy().astype(np.float32) # 转成numpy
                torch.cuda.synchronize()
                elapsed_time = time.time() - start_time
                print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))

                # 计算去噪前后图像的psnr和SSIM
                psnr_x_ = compare_psnr(x, x_)
                ssim_x_ = compare_ssim(x, x_, data_range=x.max() - x.min())
                if args.save_result:
                    name, ext = os.path.splitext(im)

                    # show(np.hstack((y, x_)))  # 展示去噪前后对比

                    # TODO 2.修改,源码报错图像报错 OSError: cannot write mode F as PNG

                    # 先clip到0-1之间,然后*255,最后转成uint8
                    x_ = np.clip(x_, 0, 1)
                    x_ *= 255
                    x_ = x_.astype(np.uint8)
                    x = np.clip(x, 0, 1)
                    x *= 255
                    x = x.astype(np.uint8)
                    y = np.clip(y, 0, 1)
                    y *= 255
                    y = y.astype(np.uint8)

                    plt = np.hstack((x, y, x_))

                    save_result(x_, path=os.path.join(args.result_dir, set_cur, name+'_dncnn'+ext))  # save the denoised image
                psnrs.append(psnr_x_)
                ssims.append(ssim_x_)
        psnr_avg = np.mean(psnrs)
        ssim_avg = np.mean(ssims)
        psnrs.append(psnr_avg)
        ssims.append(ssim_avg)
        # psnrs和ssims是测试集按顺序的每张图像的值,最后一个值是测试集平均值
        if args.save_result:
            save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt'))
        log('Datset: {0:10s} \n  PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))

σ=15,DnCNN-S在Set68上的评估结果:
在这里插入图片描述
σ=25,DnCNN-S在Set68上的评估结果:
在这里插入图片描述
σ=50,DnCNN-S在Set68上的评估结果:
在这里插入图片描述
与论文表2基本一致:
在这里插入图片描述
除了控制台显示的结果外,当save_result参数设置为1时,results文件夹下会有对应测试集的去噪结果和每张图像的评估指标(results.txt)。

去噪结果展示,σ=15,25,50,从左至右依次为Groud truth,noisy image,denoising result

“parrot”:
在这里插入图片描述

在这里插入图片描述

请添加图片描述
BSD68:
在这里插入图片描述
在这里插入图片描述
请添加图片描述
对于Set12数据集,σ=15,25,50的每个results.txt中的PSNR也与论文中表3基本一致。

补充内容

训练和测试DnCNN-B

DnCNN-B与DnCNN-S的区别是随机噪声范围,即给每个patch添加[0,55]内随机噪声level的噪声,并且patch_size由40变为50,stride由10变为7,BSD400的图像块数量为386688,与128×3000接近。此外,DnCNN的层数depth由17变为20。

修改后的DenoisingDataset类如下:

class DenoisingDataset(Dataset):
    """Dataset wrapping tensors.
    Arguments:
        xs (Tensor): clean image patches, (n,c,h,w)四维张量,在训练前制作好,n是总图像块数量
        sigma: noise level, e.g., 25
    """
    def __init__(self, xs, sigma):
        super(DenoisingDataset, self).__init__()
        self.xs = xs
        self.sigma = sigma

    def __getitem__(self, index):
        batch_x = self.xs[index]  # 每个图像块
        # TODO 1.修改噪声水平的输入, 每个图像块添加随机噪声level
        if len(self.sigma) == 1:
            sigma = self.sigma[0]
        else:
            sigma = random.randint(self.sigma[0], self.sigma[1])
        # 噪声生成:生成与batch_x相同形状满足标准正太分布的张量,然后按元素乘[0,255]像素范围内的噪声标准差
        noise = torch.randn(batch_x.size()).mul_(sigma/255.0)
        batch_y = batch_x + noise
        return batch_y, batch_x

    def __len__(self):
        return self.xs.size(0)

main_train.py中的sigma参数由int变为str,值为’0,55’,然后存为列表形式,以便读取上下界。

BSD68平均PSNR测试结果:

Datasetnoise levelDnCNN-S(Paper)DnCNN-S(本文复现)DnCNN-B(Paper)DnCNN-B(本文复现)
BSD68σ = 1531.7331.7331.6131.61
σ = 2529.2329.2429.1629.14
σ = 5026.2326.2526.2326.21

Set12单张图像PSNR以及平均PSNR测试结果:

ImagesC.manHousePeppersStarfishMonar.Airpl.ParrotLenaBarbaraBoatManCoupleAverage
noise levelσ = 15
DnCNN-S(Paper)32.6134.9733.3032.2033.0931.7031.8334.6232.6432.4232.4632.4732.859
DnCNN-S(本文复现)32.6534.9933.3032.1433.2431.7031.8934.5732.7032.4332.4432.4432.879
DnCNN-B(Paper)32.1034.9333.1532.0232.9431.5631.6334.5632.0932.3532.4132.4132.680
DnCNN-B(本文复现)32.1934.9633.1831.9533.0931.5731.7034.5032.3332.3632.3632.3832.719
noise levelσ = 25
DnCNN-S(Paper)30.1833.0630.8729.4130.2829.1329.4332.4430.0030.2130.1030.1230.436
DnCNN-S(本文复现)30.2633.1330.8129.3930.4529.0829.4432.4230.0530.2130.0830.0930.455
DnCNN-B(Paper)29.9433.0530.8429.3430.2529.0929.3532.4229.6930.2030.0930.1030.362
DnCNN-B(本文复现)30.0433.0430.7929.2430.3529.0329.3432.3729.7830.1630.0530.0530.359
noise levelσ = 50
DnCNN-S(Paper)27.0330.0027.3225.7026.7825.8726.4829.3926.2227.2027.2426.9027.178
DnCNN-S(本文复现)27.3130.0027.3925.7026.8625.8326.4729.3426.2527.2127.2026.8927.210
DnCNN-B(Paper)27.0330.0227.3925.7226.8325.8926.4829.3826.3827.2327.2326.9127.206
DnCNN-B(本文复现)27.2829.9527.4125.6826.8325.8226.4329.3326.2627.2027.1726.8827.190

训练和测试CDnCNN-B

训练集使用CBSD432,测试集使用CBSD68

data_generator.py的改动位置

  • gen_patches的cv2.imread(file_name, 1),flags参数由0改为1,因为是RGB图像
  • gen_patches的img.shape为hwc,多了一个通道维度
  • img_scaled = cv2.resize(img, (h_scaled, w_scaled), interpolation=cv2.INTER_CUBIC)改为
    img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC) ,因为CBSD432中的图像宽高不同,cv2的resize是(w,h),而不是(h,w)。BSD400是180×180的图像,所以不影响。源码的issue中也提到了这个错误。
  • datagenerator中的file_list = glob.glob(data_dir+'/*.png')改为file_list = glob.glob(data_dir+'/*.*') ,因为CBSD432中的图像是.jpg格式,这样修改适配任意图像格式的数据集。
  • datagenerator中的data = np.expand_dims(data, axis=3)注释掉,因为data已经是nhw3了

main_train.py的改动位置

  • DnCNN的image_channels参数由1改为3,depth为20

经过测试,由于main_test.py对于彩色图像的输入和读取方式有问题,产生的去噪图像结构模糊。所以,我们单写一个测试RGB图像的脚本,使用PIL库来读取图像,好处是图像在Numpy和Tensor转换、可视化等方面比较方便。

新建一个名为test_RGB.py文件:

# -*- coding: utf-8 -*-

# =============================================================================
#  @article{zhang2017beyond,
#    title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#    journal={IEEE Transactions on Image Processing},
#    year={2017},
#    volume={26},
#    number={7},
#    pages={3142-3155},
#  }
# by Kai Zhang (08/2018)
# cskaizhang@gmail.com
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# run this to test the model

import argparse
import os, time, datetime
# 图像读取方式不同,可能指标不同
# import PIL.Image as Image
import PIL.Image as pil_image
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch
from torchvision import transforms

# 旧版计算指标已失效
# from skimage.measure import compare_psnr, compare_ssim

from skimage.io import imread, imsave
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    parser.add_argument('--set_names', default=['CBSD68'], help='directory of test dataset')
    parser.add_argument('--sigma', default=45, type=int, help='noise level')
    parser.add_argument('--model_dir', default=os.path.join('models', 'CDnCNN-B_sigma[0, 55]_old'),
                        help='directory of the model')
    parser.add_argument('--model_name', default='model_180.pth', type=str, help='the model name')
    parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset')
    parser.add_argument('--save_result', default=1, type=int, help='save the denoised image, 1 or 0')
    return parser.parse_args()


def log(*args, **kwargs):
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


# 保存结果
def save_result(result, path):
    path = path if path.find('.') != -1 else path + '.png'
    ext = os.path.splitext(path)[-1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        # imsave(path, np.clip(result, 0, 1))
        imsave(path, result)


def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    # plt.imshow(x, interpolation='nearest', cmap='gray')
    plt.imshow(x, interpolation='nearest')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


# 测试时修改对应depth
class DnCNN(nn.Module):

    def __init__(self, depth=20, n_channels=64, image_channels=3, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []
        layers.append(
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding,
                      bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth - 2):
            layers.append(
                nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding,
                          bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(
            nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding,
                      bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y - out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


if __name__ == '__main__':

    args = parse_args()

    # model = DnCNN()
    if not os.path.exists(os.path.join(args.model_dir, args.model_name)):

        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        # load weights into new model
        log('load trained model on Train400 dataset by kai')
    else:
        # model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        log('load trained model')

    #    params = model.state_dict()
    #    print(params.values())
    #    print(params.keys())
    #
    #    for key, value in params.items():
    #        print(key)    # parameter name
    #    print(params['dncnn.12.running_mean'])
    #    print(model.state_dict())

    model.eval()  # evaluation mode
    #    model.train()

    if torch.cuda.is_available():
        model = model.cuda()

    if not os.path.exists(args.result_dir):
        os.mkdir(args.result_dir)

    for set_cur in args.set_names:  # 对于每一个测试集

        # 生成与测试集同名的结果文件夹
        if not os.path.exists(os.path.join(args.result_dir, set_cur + "_sigma{}".format(args.sigma))):
            os.mkdir(os.path.join(args.result_dir, set_cur + "_sigma{}".format(args.sigma)))

        # 保存每个训练集的平均指标
        psnrs = []
        ssims = []

        # 获取到测试集中的图像
        for im in os.listdir(os.path.join(args.set_dir, set_cur)):
            name, ext = os.path.splitext(im)
            # 如果后缀是图像格式
            if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):
                x = pil_image.open(os.path.join(args.set_dir, set_cur, im)).convert('RGB')
                GT = x
                noise = np.random.normal(0.0, args.sigma, (x.height, x.width, 3)).astype(np.float32)
                y = np.array(x).astype(np.float32) + noise
                y /= 255.0
                input = y
                y_ = transforms.ToTensor()(y).unsqueeze(0).cuda()
                with torch.no_grad():
                    x_ = model(y_)

                output = x_.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
                output = pil_image.fromarray(output, mode='RGB')
                output.save(os.path.join(args.result_dir, set_cur + "_sigma{}".format(args.sigma),
                                                       name + '_dncnn' + ext))


                # 对比图 顺序为GT、Bicubic、RDN
                fig, axes = plt.subplots(1, 3)
                # 关闭坐标轴
                for ax in axes:
                    ax.axis('off')

                # 在每个子图中显示对应的图像
                axes[0].imshow(GT)
                axes[0].set_title('Clean Image')
                axes[1].imshow(input)
                axes[1].set_title('Noisy Image')
                axes[2].imshow(output)
                axes[2].set_title('CDnCNN-B result')

                # 保存图像
                plt.savefig(os.path.join(args.result_dir, set_cur + "_sigma{}".format(args.sigma),
                                         name + '_plt_dncnn' + ext),bbox_inches='tight', dpi=600)

CDnCNN-B,在CBSD68上的部分测试结果如下:

sigma = 15

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

sigma = 25

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

sigma = 35

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

sigma = 45

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

sigma = 55

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

训练和测试DnCNN-3

准备工作和训练

参数:

训练集:291数据集(T91+BSD200)
noise level:[0,55]
scale:2,3,4
quality_factor:[5,99]
patch_size: 50×50
patch数量:128×8000

注:由于128×8000数据量过大,我的设备会爆显存。所以训练DnCNN-3时增大了步长,减少了patch数量,模型的性能一般。所以,本节只展示图像效果,而没有量化评估结果。(但有计算PSNR/SSIM的代码)

对于去噪: 加噪方式同前面的模型。

对于超分: 由于DnCNN本身不具有超分的能力,所以输入是Bicubic先缩小再放大(同SRCNN,也就是x1的去模糊,x234表示退化的模糊程度)

对于JPEG去块: 将图像保存成JPEG格式,并指定JPEG图像的质量,以实现JPEG encoder

由于源码的DenoisingDataset中图像是以Tensor形式加噪的,那么对于超分和JPEG去块,我们先将Tensor图像转为PIL,在缩放或JPEG encoder后,再转回Tensor形式,实现输入的制作。

修改后的DenoisingDataset如下:

class DenoisingDataset(Dataset):
    """Dataset wrapping tensors.
    Arguments:
        xs (Tensor): clean image patches, (n,c,h,w)四维张量,在训练前制作好,n是总图像块数量
        sigma: noise level, e.g., 25
        downsampling_factor: 超分下采样因子
        jpeg_quality:JPEG去块范围
    """
    def __init__(self, xs, sigma, downsampling_factor, jpeg_quality):
        super(DenoisingDataset, self).__init__()
        self.xs = xs
        self.sigma = sigma
        self.downsampling_factor = downsampling_factor
        self.jpeg_quality = jpeg_quality

    def __getitem__(self, index):
        batch_x = self.xs[index]  # 每个图像块
        noisy_x = batch_x

        if self.sigma is not None:
            # TODO 1.修改噪声水平的输入, 每个图像块添加随机噪声level
            if len(self.sigma) == 1:
                sigma = self.sigma[0]
            else:
                sigma = random.randint(self.sigma[0], self.sigma[1])
            # 噪声生成:生成与batch_x相同形状满足标准正太分布的张量,然后按元素乘[0,255]像素范围内的噪声标准差
            # noise = torch.randn(batch_x.size()).mul_(sigma/255.0)
            noise = torch.randn(noisy_x.size()).mul_(sigma / 255.0)
            noisy_x = noisy_x + noise

        # TODO 3.在此修改,对照另一个源码,添加超分和JPEG去块
        if self.downsampling_factor is not None:
            if len(self.downsampling_factor) == 1:
                downsampling_factor = self.downsampling_factor[0]
            else:
                downsampling_factor = random.randint(self.downsampling_factor[0], self.downsampling_factor[1])

            noisy_x = transforms.ToPILImage()(noisy_x).convert("RGB")
            noisy_x = noisy_x.resize((patch_size // downsampling_factor,
                                              patch_size // downsampling_factor),
                                             resample=pil_image.BICUBIC)
            noisy_x = noisy_x.resize((patch_size, patch_size), resample=pil_image.BICUBIC)
            noisy_x = transforms.ToTensor()(noisy_x)

        if self.jpeg_quality is not None:
            if len(self.jpeg_quality) == 1:
                quality = self.jpeg_quality[0]
            else:
                quality = random.randint(self.jpeg_quality[0], self.jpeg_quality[1])

            noisy_image = transforms.ToPILImage()(noisy_x).convert("RGB")
            buffer = io.BytesIO()
            noisy_image.save(buffer, format='jpeg', quality=quality)
            noisy_image = pil_image.open(buffer)
            noisy_x = transforms.ToTensor()(noisy_image)

        # batch_y = batch_x + noise
        return noisy_x, batch_x

    def __len__(self):
        return self.xs.size(0)

添加了超分的放大倍数和JPEG质量参数,对应的main_train.py中也添加这两个参数,在制作数据集时将这两个参数传入即可。

测试DnCNN-3

新建test_3.py,用于测试DnCNN-3,代码如下:

# 测试DnCNN-3,计算测试集平均PSNR/SSIM,保存结果
# 复现论文表5和图10-12结果

import argparse
import os, time, datetime, io
# 图像读取方式不同,可能指标不同
# import PIL.Image as Image
import PIL.Image as pil_image
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch
from torchvision import transforms

# 旧版计算指标已失效
# from skimage.measure import compare_psnr, compare_ssim

from skimage.io import imread, imsave
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    parser.add_argument('--set_names', default=['URBAN100'], help='directory of test dataset')
    parser.add_argument('--sigma', type=int, default=None, help='noise level')
    parser.add_argument('--jpeg_quality', type=int, default=None)
    parser.add_argument('--downsampling_factor', type=int, default=4)
    parser.add_argument('--model_dir', default=os.path.join('models', 'DnCNN-3_sigma055'),
                        help='directory of the model')
    parser.add_argument('--model_name', default='model_026.pth', type=str, help='the model name')
    parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset')
    parser.add_argument('--save_result', default=1, type=int, help='save the denoised image, 1 or 0')
    return parser.parse_args()


def log(*args, **kwargs):
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


# 保存结果
def save_result(result, path):
    path = path if path.find('.') != -1 else path + '.png'
    ext = os.path.splitext(path)[-1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        # imsave(path, np.clip(result, 0, 1))
        imsave(path, result)


def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    # plt.imshow(x, interpolation='nearest', cmap='gray')
    plt.imshow(x, interpolation='nearest')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


# 测试时修改对应depth
class DnCNN(nn.Module):

    def __init__(self, depth=20, n_channels=64, image_channels=3, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []
        layers.append(
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding,
                      bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth - 2):
            layers.append(
                nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding,
                          bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(
            nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding,
                      bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y - out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


if __name__ == '__main__':

    args = parse_args()

    # model = DnCNN()
    if not os.path.exists(os.path.join(args.model_dir, args.model_name)):

        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        # load weights into new model
        log('load trained model on Train400 dataset by kai')
    else:
        # model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        log('load trained model')

    #    params = model.state_dict()
    #    print(params.values())
    #    print(params.keys())
    #
    #    for key, value in params.items():
    #        print(key)    # parameter name
    #    print(params['dncnn.12.running_mean'])
    #    print(model.state_dict())

    model.eval()  # evaluation mode
    #    model.train()

    if torch.cuda.is_available():
        model = model.cuda()

    if not os.path.exists(args.result_dir):
        os.mkdir(args.result_dir)

    for set_cur in args.set_names:  # 对于每一个测试集

        # 生成与测试集同名的结果文件夹
        if not os.path.exists(os.path.join(args.result_dir, set_cur)):
            os.mkdir(os.path.join(args.result_dir, set_cur))

        # 保存每个训练集的平均指标
        psnrs = []
        ssims = []

        # 获取到测试集中的图像
        for im in os.listdir(os.path.join(args.set_dir, set_cur)):
            name, ext = os.path.splitext(im)
            descriptions = ''
            # 如果后缀是图像格式
            if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):
                x = pil_image.open(os.path.join(args.set_dir, set_cur, im)).convert('RGB')
                GT = x
                if args.sigma is not None:
                    noise = np.random.normal(0.0, args.sigma, (x.height, x.width, 3)).astype(np.float32)
                    y = np.array(x).astype(np.float32) + noise
                    y /= 255.0
                    input = y
                    descriptions += '_sigma_{}'.format(args.sigma)

                if args.jpeg_quality is not None:
                    buffer = io.BytesIO()
                    x.save(buffer, format='jpeg', quality=args.jpeg_quality)
                    x = pil_image.open(buffer)
                    y = np.array(x).astype(np.float32)
                    y /= 255.0
                    input = y
                    descriptions += '_jpeg_{}'.format(args.jpeg_quality)

                if args.downsampling_factor is not None:
                    original_width = x.width
                    original_height = x.height
                    x = x.resize((x.width // args.downsampling_factor,
                                          x.height // args.downsampling_factor),
                                         resample=pil_image.BICUBIC)
                    x = x.resize((original_width, original_height), resample=pil_image.BICUBIC)
                    y = np.array(x).astype(np.float32)
                    y /= 255.0
                    input = y
                    descriptions += '_sr_x{}'.format(args.downsampling_factor)

                y_ = transforms.ToTensor()(y).unsqueeze(0).cuda()
                with torch.no_grad():
                    x_ = model(y_)

                # output = x_.mul_(255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
                output = x_.mul_(255.0).clamp_(0.0, 255).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
                output = pil_image.fromarray(output, mode='RGB')
                output.save(os.path.join(args.result_dir, set_cur,
                                                       name + '_dncnn_{}'.format(descriptions) + '.png'))


                # 对比图 顺序为GT、Bicubic、RDN
                fig, axes = plt.subplots(1, 3)
                # 关闭坐标轴
                for ax in axes:
                    ax.axis('off')

                # 在每个子图中显示对应的图像
                axes[0].imshow(GT)
                axes[0].set_title('Ground-truth')
                axes[1].imshow(input)
                axes[1].set_title('{}'.format(descriptions))
                axes[2].imshow(output)
                axes[2].set_title('DnCNN-3 result')

                # 保存图像
                plt.savefig(os.path.join(args.result_dir, set_cur, name + '_plt_dncnn' + '.png'), bbox_inches='tight', dpi=600)


                # 计算指标
                # x = np.array(x).astype(np.float32) / 255.
                # x_ = x_.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
                # x_ = x_.astype(np.float32) / 255.
                # psnr_x_ = compare_psnr(x, x_)
                # ssim_x_ = compare_ssim(x, x_, data_range=x.max() - x.min(), channel_axis=2)
                psnr_x_ = compare_psnr(np.array(GT), np.array(output))
                ssim_x_ = compare_ssim(np.array(GT), np.array(output), data_range=np.array(GT).max() - np.array(GT).min(), channel_axis=2)
                psnrs.append(psnr_x_)
                ssims.append(ssim_x_)
        psnr_avg = np.mean(psnrs)
        ssim_avg = np.mean(ssims)
        psnrs.append(psnr_avg)
        ssims.append(ssim_avg)
        # psnrs和ssims是测试集按顺序的每张图像的值,最后一个值是测试集平均值
        if args.save_result:
            save_result(np.hstack((psnrs, ssims)),
                        path=os.path.join(args.result_dir, set_cur, 'results.txt'))
        log('Datset: {0:10s} \n  PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))

测试集放在data/Test下,根据需要修改三种任务的参数即可。执行后results文件夹会保存对应测试集的结果。

DnCNN-3结果展示:

x4超分:
在这里插入图片描述
在这里插入图片描述
q=40,JPEG去块(可能看不太清,放大看右下角JPEG块还是挺明显的。):
在这里插入图片描述
在这里插入图片描述
结果不展示太多了,大家可以根据我给出的模型自行测试。

其他

还有可以挖掘和学习的东西,大家感兴趣的可以自行实现。

  1. 复现论文图2:使用BSD68作为验证集,保存每个epoch的PSNR,最终画出变化曲线。原作者项目中的Pytorch代码没有验证过程,可以参照测试代码中的计算指标,在每个epoch训练后添加验证。
  2. 复现论文图13:将输入图像分成6块,每个块各自添加三种任务的不同参数。初始化一个与输入图像一样大的全零Numpy,然后按照对应位置赋值,最终将该图像输出DnCNN-3,输出就是三种任务的重建结果。Residual Image用输入减输出(或者输出减输入)即可可视化。
  3. 复现表4:计算推理时间。

代码及各模型下训练好的权重文件下载

和源码相比,由于改动较多(包括修正源码bug适配高版本的Pytorch、添加补充内容等),所以代码+模型整体打包下载,怕有遗漏。

算是请我喝杯咖啡,继续为大家产出优质的去噪复现,感谢大家支持!

下载地址:图像去噪DnCNN的Pytorch完复现代码,源码基础上添加DnCNN-B/CDnCNN-B、DnCNN-3的训练和测试复现

在这里插入图片描述

代码使用说明

每个模型的训练和测试请按步骤执行。

训练集放到data文件夹中,测试集放在data/Test文件夹中。

DnCNN-S

  1. data_generator.py中

patch_size, stride = 40, 10;
get_patches中imread的flags改为0,代表读取的是灰度图;
h, w = img.shape,cv2读灰度图形状是2维

  1. main_train.py
    (1)先修改DnCNN中image_channels=1(通道为3也可以训练灰度图
    (2)参数设置(windows运行请修改对应位置的参数,空的值为None,Linux运行执行下面命令):
python main_train.py --arch "DnCNN-S" \               
		             --train_data "data/Train400" \
		             --sigma 15 \
		             --downsampling_factor "" \
		             --jpeg_quality "" \
		             --epoch 180 \
		             --lr 1e-3 \    
  1. main_test.py
python main_test.py --set_dir "data/Test" \               
		            --set_names "['Set68','Set12']" \
		            --sigma 15 \
		            --model_dir "models\DnCNN-S_sigma15" \
		            --model_name "model_180.pth" \
		            --epoch 180 \
		            --result_dir "results" \    
		            --save_result 1

测试不建议用linux,windows下直接修改对应参数运行。

DnCNN-B

  1. data_generator.py中

patch_size, stride = 50, 7;
get_patches中imread的flags改为0,代表读取的是灰度图;
h, w = img.shape,cv2读灰度图形状是2维

  1. main_train.py
    (1)先修改DnCNN中image_channels=1(通道为3也可以训练灰度图
    (2)参数设置(windows运行请修改对应位置的参数,空的值为None,Linux运行执行下面命令):
python main_train.py --arch "DnCNN-B" \               
		             --train_data "data/Train400" \
		             --sigma 0,55 \
		             --downsampling_factor "" \
		             --jpeg_quality "" \
		             --epoch 180 \
		             --lr 1e-3 \    
  1. main_test.py
python main_test.py --set_dir "data/Test" \               
		            --set_names "['Set68','Set12']" \
		            --sigma 15 \
		            --model_dir "models\DnCNN-B_sigma[0, 55]" \
		            --model_name "model_180.pth" \
		            --epoch 180 \
		            --result_dir "results" \    
		            --save_result 1

测试不建议用linux,windows下直接修改对应参数运行。

CDnCNN-B

  1. data_generator.py中

patch_size, stride = 50, 19;
get_patches中imread的flags改为1,代表读取的是RGB图;
h, w, c = img.shape,cv2读灰度图形状是2维

  1. main_train.py
    (1)先修改DnCNN中image_channels=3
    (2)参数设置(windows运行请修改对应位置的参数,空的值为None,Linux运行执行下面命令):
python main_train.py --arch "CDnCNN-B" \               
		             --train_data "data/CBSD432" \
		             --sigma 0,55 \
		             --downsampling_factor "" \
		             --jpeg_quality "" \
		             --epoch 180 \
		             --lr 1e-3 \    
  1. test_RGB.py
python test_RGB.py --set_dir "data/Test" \               
		           --set_names "['CBSD60']" \
		           --sigma 15 \
		           --model_dir "models\CDnCNN-B_sigma[0, 55]" \
		           --model_name "model_180.pth" \
		           --epoch 180 \
		           --result_dir "results" \    
		           --save_result 1

测试不建议用linux,windows下直接修改对应参数运行。

DnCNN-3

  1. data_generator.py中

patch_size, stride = 50, 4; (爆显存就增大步长,这个值是针对291数据集的)
get_patches中imread的flags改为1,代表读取的是RGB图;
h, w, c = img.shape,cv2读灰度图形状是2维

  1. main_train.py
    (1)先修改DnCNN中image_channels=3
    (2)参数设置(windows运行请修改对应位置的参数,空的值为None,Linux运行执行下面命令):
python main_train.py --arch "DnCNN-3" \               
		             --train_data "data/291" \
		             --sigma 0,55 \
		             --downsampling_factor 1,4 \
		             --jpeg_quality 5,99 \
		             --epoch 180 \
		             --lr 1e-3 \    
  1. test_3.py
python test_3.py --set_dir "data/Test" \               
		         --set_names "['Set5']" \
		         --sigma "" \
		         --jpeg_quality "" \
		         --downsampling_factor 4 \
		         --model_dir "models\DnCNN-3_sigma[0, 55]" \
		         --model_name "model_180.pth" \
		         --epoch 180 \
		         --result_dir "results" \    
		         --save_result 1

上述例子为只进行x4超分,如果需要其他任务请修改对应的参数。

至此本文结束。

如果本文对你有所帮助,请点赞收藏,创作不易,感谢您的支持!

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

十小大

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值