【图像去噪】论文复现:全网最细的Pytorch版本实现MemNet!论文中的网络结构图与代码中的每个变量一一对应!实现思路一目了然!附完整代码和训练好的模型权重文件!

请先看【专栏介绍文章】:【图像去噪(Image Denoising)】关于【图像去噪】专栏的相关说明,包含适配人群、专栏简介、专栏亮点、阅读方法、定价理由、品质承诺、关于更新、去噪概述、文章目录、资料汇总、问题汇总(更新中)

完整代码和训练好的模型权重文件下载链接见本文底部,订阅专栏免费获取!

本文亮点:

  • 跑通训练和测试代码,轻松运行,按步骤执行保证无任何运行问题
  • Pytorch实现Basic MemNet architecture和Multi-supervised MemNet architecture,与原论文基本一致
  • 论文中的网络结构图与实现代码深度绑定,变量标注完整,复现逻辑清晰
  • 更换路径和参数即可训练自己的图像数据集支持灰度图和RGB图
  • 包含训练好的模型文件(共6个,每个模型对应3个噪声level=30,50,70),可直接推理运行得到去噪后的图像结果以及评价指标PSNR/SSIM
  • 数据处理、模型训练和验证、推理测试全流程讲解,无论是科研还是应用,新手小白都能看懂,学习阅读毫无压力,去噪入门必看


前言

论文题目:MemNet: A Persistent Memory Network for Image Restoration —— MemNet:用于图像恢复的持久记忆网络

论文地址:MemNet: A Persistent Memory Network for Image Restoration

论文源码:https://github.com/tyshiwo/MemNet

对应的论文精读:【图像去噪】论文精读:MemNet: A Persistent Memory Network for Image Restoration

由于源码是Caffe实现的,本文实现Pytorch版本的MemNet。

一、跑通代码 (Quick Start)

项目文件说明:
在这里插入图片描述

  • data:去噪后图像结果保存位置
  • datasets:数据集所在文件夹
  • Plt:训练过程指标曲线可视化位置(Loss、PSNR、SSIM与Epoch关系曲线)
  • weights:训练模型保存位置
  • dataset.py:封装数据集
  • draw_evaluation.py:绘制指标曲线
  • test_benchmark.py:计算测试集指标;保存去噪后图像
  • memnet.py:MemNet模型基础版本实现
  • memnet1.py:MemNet模型多监督版本实现
  • README.md:相关说明
  • train.py:训练MemNet
  • utils.py:工具类

1.1 数据集准备

下载BSD500数据集,将全部图像转成灰度图(也可以用RGB图像训练,模型的输入通道改为3,本文与原论文一致,使用灰度图)。

  • 训练集:BSD500中的train,路径为:datasets/train/BSD200
  • 验证集:BSD500中的val,路径为:datasets/val/BSD100
  • 测试集:BSD500中的test、S14,路径为datasets/test/BSD200datasets/test/S14

将数据集放在对应的位置,如果要训练自己的数据集,按相同位置放置即可。数据集都是clean_image,通过代码内部加噪,无需加噪后的图像。

1.2 训练

根据需要设置train.py中的参数:

  • arch:模型选择(MemNet_MS(多监督结构), MemNet_BS(基础结构)),默认为MemNet_BS
  • images_dir:训练集路径
  • clean_valid_dir:验证集路径
  • outputs_dir:模型保存位置
  • gaussian_noise_level:高斯白噪声水平sigma(30,50,70)
  • patch_size:块大小,默认为31
  • stride:裁剪图像块步长,默认为21
  • batch_size:默认为64
  • num_epochs:训练轮数,本例设置为50
  • start_epoch:开始轮数(如果训练因为某种原因中断,接续训练的模型所在的epoch)
  • resume:接续训练的模型路径
  • lr:学习率(SGD为1e-1,Adam为1e-4)
  • lr-decay-steps:学习率多少个epoch后开始下降
  • lr-decay-gamma:下降比例,比如0.1就是下降10倍
  • clip:梯度裁剪参数,使用SGD优化器时使用
  • momentum:SGD的动量,默认为0.9
  • weight-decay:SGD权重衰减,默认为1e-4
  • threads:num_workers,默认为8
  • epoch_save_num:每多少个epoch保存一次指标和模型

参数设置好后(Linux用readme中的命令,Windows直接改参数default),运行main.py即可训练,控制台会显示训练进度,以及验证集指标。

1.3 测试

设置test_benchmark.py中的参数:

  • arch:模型选择
  • weights_path:训练好的模型文件路径(单个文件)
  • images_dir:测试集路径(文件夹)
  • outputs_denoising_dir:images_dir中每张图像的去噪结果路径
  • outputs_plt_dir:images_dir中每张图像的去噪结果对比路径
  • gaussian_noise_level:噪声水平,与模型的噪声水平一致

设置好参数后,执行test_benchmark.py。data文件夹中会出现去噪结果,控制台会输出测试集信息以及计算好的平均PSNR和SSIM。如果想测试单张图像,那么将对应的测试集文件夹中放置一张图像即可。

结果展示:
在这里插入图片描述
在这里插入图片描述

二、代码解析

2.1 数据预处理

本节对应dataset.py。功能为:

  • 实现数据增强
  • 实现图像切块操作
  • 将训练集和验证集封装成tensor

2.1.1 数据增强

数据增强为旋转、翻转以及它们之间组合的全排列,一共8种。实现如下:

def data_aug(img, mode=0):
    # data augmentation
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(img)
    elif mode == 2:
        return np.rot90(img)
    elif mode == 3:
        return np.flipud(np.rot90(img))
    elif mode == 4:
        return np.rot90(img, k=2)
    elif mode == 5:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 6:
        return np.rot90(img, k=3)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))

通过随机数0~8选择mode,实现随机数据增强。

2.1.2 制作图像块

思路:指定图像块大小(MemNet为31×31),根据步长stride (21),类似滑动窗口一样扫描整张图像,动一下切一下,将所有的图像块放到一起作为整个训练集。

实现如下:

def gen_patch(datasets_path, patch_size, stride):
    file_list = sorted(glob.glob(datasets_path + '/*')) # 图像列表
    data = [] # 返回制作好的图像块集合
    patches = [] # 每一张图像的图像块集合

    # 按stride和patch_size得到每张图像的块集合
    for i in range(len(file_list)):
        clean_image = pil_image.open(file_list[i]).convert('RGB')
        for j in range(0, clean_image.width - patch_size + 1, stride):
            for k in range(0, clean_image.height - patch_size + 1, stride):
                x = clean_image.crop((j, k, j + patch_size, k + patch_size))
                patches.append(x)
                for m in range(0, 1):
                    x_aug = data_aug(x, mode=np.random.randint(0, 8))
                    patches.append(x_aug)

    # 得到所有的图像块集合
    for patch in patches:
        data.append(patch)

    return data

2.2 MemNet网络结构拆解

本节对应MemNet.py。

2.2.1 Memory Block

记忆块:
在这里插入图片描述

  • 记忆块包含递归单元门控单元两部分
  • 递归单元的基本结构为残差块,残差块是由两个BN+ReLU+3×3Conv构成的,输入和输出通过残差连接(ResNet,论文Sec3.2.上图中的每个蓝圈)
  • 门控单元的基本结构为残差块BN+ReLU+1×1Conv

代码实现如下:

# 记忆块
class MemoryBlock(nn.Module):
    def __init__(self, channels, num_resblock, num_memblock):
        # num_memblock为当前记忆块后还有的记忆块个数
        super(MemoryBlock, self).__init__()

        # 递归单元是由残差块构成的
        self.recursive_unit = nn.ModuleList(
            [ResidualBlock(channels) for i in range(num_resblock)]
        )

        # 门控单元的输入通道数为M+R
        self.gate_unit = GateUnit((num_resblock+num_memblock) * channels, channels, True)

    def forward(self, x, ys):
        # xs是递归单元的每个层输出,ys是每个记忆块的门控输出
        xs = []
        residual = x
        for layer in self.recursive_unit:
            x = layer(x)
            xs.append(x)

        # 门控单元的输入应该是短期记忆和长期记忆的输出的堆叠
        gate_out = self.gate_unit(torch.cat(xs+ys, 1))
        ys.append(gate_out)
        return gate_out

# 残差块(ResNet,Sec 3.2)
class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.relu_conv1 = BNReLUConv(channels, channels, True)
        self.relu_conv2 = BNReLUConv(channels, channels, True)
        
    def forward(self, x):
        residual = x
        out = self.relu_conv1(x)
        out = self.relu_conv2(out)
        out = out + residual
        return out

# 基本结构为BN+ReLU+3×3Conv
class BNReLUConv(nn.Sequential):
    def __init__(self, in_channels, channels, inplace=True):
        super(BNReLUConv, self).__init__()
        self.add_module('bn', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=inplace))
        self.add_module('conv', nn.Conv2d(in_channels, channels, 3, 1, 1))

# 门控结构为BN+ReLU+1×1Conv
class GateUnit(nn.Sequential):
    def __init__(self, in_channels, channels, inplace=True):
        super(GateUnit, self).__init__()
        self.add_module('bn',nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=inplace))
        self.add_module('conv', nn.Conv2d(in_channels, channels, 1, 1, 0))

2.2.2 Basic MemNet architecture

基本MemNet结构:
在这里插入图片描述
图2:蓝色标注为代码实现中对应的变量。图中矩形结构对应__init__中的模块定义;箭头对应forward中前向传播过程。

  • MemNet的结构包含三个部分:特征提取网络FENet多个堆叠的记忆块重建网络ReconNet
  • FENet:BN+ReLU+3×3Conv
  • 多个堆叠的记忆块:M个记忆块,每个记忆块的直接输出为短期连接(记忆块本身结构输出)、与后面的每个记忆块做长期连接(门控输出)
  • ReconNet:BN+ReLU+3×3Conv

注:递归单元的最基本模块、门控单元、FENet、ReconNet的基本结构都是BN+ReLU+Conv。找到公共模块,可以简化模型复现。

代码如下:

class MemNet(nn.Module):
    def __init__(self, in_channels, channels, num_memblock, num_resblock):
        super(MemNet, self).__init__()
        self.feature_extractor = BNReLUConv(in_channels, channels, True)  # FENet
        self.reconstructor = BNReLUConv(channels, in_channels, True)      # ReconNet

        # 全部记忆块
        self.dense_memory = nn.ModuleList(
            [MemoryBlock(channels, num_resblock, i+1) for i in range(num_memblock)]
        )

    #Base MemNet architecture
    def forward(self, x):
        residual = x   # 残差输出
        out = self.feature_extractor(x) # FENet输出
        ys = [out]  # FENet输出和每个残差块输出列表
        for memory_block in self.dense_memory:
            out = memory_block(out, ys) # 每个记忆块输出
        out = self.reconstructor(out)   # 重建层输出
        out = out + residual # 最终输出
        
        return out

2.2.3 Multi-supervised MemNet architecture

多监督MemNet结构:
在这里插入图片描述

  • 多监督结构与基本结构的区别在于:每个记忆块的输出都送入重建网络得到M个输出,通过权重分配将这些输出按元素相加得到最终输出。
  • 代码区别:多了权重weights,多了每个记忆块的输出
  • 关于weights:初始每个记忆块都是相同的,学习过程中自适应变化(论文3.3)

代码实现如下:

class MemNet(nn.Module):
    def __init__(self, in_channels, channels, num_memblock, num_resblock):
        super(MemNet, self).__init__()
        self.feature_extractor = BNReLUConv(in_channels, channels, True)  # FENet
        self.reconstructor = BNReLUConv(channels, in_channels, True)      # ReconNet

        # 全部记忆块
        self.dense_memory = nn.ModuleList(
            [MemoryBlock(channels, num_resblock, i+1) for i in range(num_memblock)]
        )

        # 多监督的权重
        self.weights = nn.Parameter((torch.ones(1, num_memblock)/num_memblock), requires_grad=True)  

    #Multi-supervised MemNet architecture
    def forward(self, x):
        residual = x # 残差输出
        out = self.feature_extractor(x) # FENet输出
        w_sum=self.weights.sum(1)  # 权重总和
        mid_feat=[] # 存每个记忆块的短路径输出(短期)
        ys = [out]  # 存每个记忆块的长路径输出
        for memory_block in self.dense_memory:
            out = memory_block(out, ys)  # 每个记忆块的输出(短路经)
            mid_feat.append(out);   # 添加到list中

        # output1
        pred = (self.reconstructor(mid_feat[0])+residual)*self.weights.data[0][0]/w_sum
        for i in range(1,len(mid_feat)):
            # output1~outputM按元素相加
            pred = pred + (self.reconstructor(mid_feat[i])+residual)*self.weights.data[0][i]/w_sum

        return pred

2.3 训练

本节对应train.py。

实现细节:

  • 训练集:BSD500的训练集(200张)
  • 验证集:BSD500的验证集(100张)
  • 测试集:BSD500的测试集(200张)、S14
  • 噪声水平:σ=30,50,70
  • 输入:灰度图
  • 块大小:31×31,步长为21
  • 模型版本:基础结构和多监督结构
  • 模型参数:6个记忆块,每个块6个递归层(M6R6);自监督中α=1/(M+1);所有卷积层特征数为64;门控单元卷积核为1×1,其余为3×3;使用kaiming权重初始化
  • 优化器:SGD,动量0.9,权重衰减1e-4
  • batch_size:64
  • lr:初始0.1,每20个epoch减小10倍

</font color=‘red’>注:Pytorch使用SGD,在BSD200张训练集的情况下,batch_size为64,lr为0.1,会出现梯度消失。解决办法是增大batch_size(硬件不够好会爆显存)或降低学习率(从0.1开始逐渐下降是为了验证集上的指标曲线好看,由低到高。但是0.1作为高学习率,可能还没等lr降低就梯度消失了);或者更换为Adam优化器,然后选择较低的学习率(REDNet论文中推荐)。最后,训练参数没必要与原论文完全一样,因为pytorch和caffe之间可能有很大区别,无法用pytorch准确实现caffe的结果,以学习为主不必过于较真。

2.3.1 训练过程拆解

训练过程包括以下几个部分:

  • 参数设置:parser设置模型参数,见本文1.2
  • 定义模型:model = MemNet_BS(3, 64, 6, 6)
  • checkpoint设置:加载模型,从该模型的epoch继续训练
  • 损失函数定义:criterion = nn.MSELoss()
  • 优化器设置:optimizer = optim.Adam(model.parameters(), lr=opt.lr)optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.momentum, momentum=opt.weight_decay)
  • DataLoader导入数据集:dataloader = DataLoader(dataset=dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.threads, pin_memory=True, drop_last=True)
  • 训练过程框架:读取dataloader获得输入和标签,将输入送入模型,反向传播,梯度更新
  • 训练过程指标保存:results字典保存loss、psnr、ssim
  • 训练过程模型保存:torch.save

2.3.2 验证过程拆解

每一个epoch执行完后,要用验证集得到该epoch的指标。

  1. 模型转为验证模式:model.eval()
  2. 读取验证集,将输入送入模型得到preds
  3. 计算该epoch的模型输出和Groud-truth之间的PSNR和SSIM
  4. 判断并保存指标最优的模型
  5. 用pandas将指标保存在表格中

2.3.3 完整训练代码

import argparse
import os
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from memnet1 import MemNet
from memnet import MemNet_BS
from dataset import Dataset, EvalDataset, gen_patch
from utils import AverageMeter, adjust_lr
import copy
import numpy as np
import pandas as pd

from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr


cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='MemNet_BS', help='MemNet_MS,MemNet_BS')
    parser.add_argument('--images_dir', type=str, default='datasets/train/BSD200')
    parser.add_argument('--clean_valid_dir', type=str, default='datasets/test/S14')
    parser.add_argument('--outputs_dir', type=str, default='weights')
    parser.add_argument('--gaussian_noise_level', type=str, default='50') # 30、50、70
    parser.add_argument('--downsampling_factor', type=str, default=None)
    parser.add_argument('--jpeg_quality', type=str, default=None)
    parser.add_argument('--patch_size', type=int, default=31)
    parser.add_argument('--stride', type=int, default=21)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_epochs', type=int, default=50)
    parser.add_argument('--start_epoch', type=int, default=0)   # 从第几轮开始
    parser.add_argument("--resume", default='', type=str)   # 从哪个权重模型继续训练
    parser.add_argument('--lr', type=float, default=1e-4)  # SGD 0.1, Adam 1e-4
    parser.add_argument('--lr-decay-steps', type=int, default=50)  # 多少轮后开始下降
    parser.add_argument('--lr-decay-gamma', type=float, default=0.1)    # 下降10倍
    parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4")
    parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
    parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="Weight decay, Default: 1e-4")
    parser.add_argument('--threads', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--epoch_save_num', type=int, default=1)  # 每多少轮保存指标
    opt = parser.parse_args()

    if opt.gaussian_noise_level is not None:
        opt.gaussian_noise_level = list(map(lambda x: int(x), opt.gaussian_noise_level.split(',')))

    if opt.downsampling_factor is not None:
        opt.downsampling_factor = list(map(lambda x: int(x), opt.downsampling_factor.split(',')))

    if opt.jpeg_quality is not None:
        opt.jpeg_quality = list(map(lambda x: int(x), opt.jpeg_quality.split(',')))

    if not os.path.exists(opt.outputs_dir):
        os.makedirs(opt.outputs_dir)

    torch.manual_seed(opt.seed)

    if opt.arch == 'MemNet_BS':
        model = MemNet_BS(3, 64, 6, 6)
        model = model.to(device)
    if opt.arch == 'MemNet_MS':
        model = MemNet(3, 64, 6, 6)
        model = model.to(device)

    criterion = nn.MSELoss()

    # 用SGD loss不下降
    # optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.momentum, momentum=opt.weight_decay)

    # Adam lr=0.1 会梯度消失
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))  # 显示信息
            checkpoint = torch.load(opt.resume) # 加载模型
            opt.start_epoch = checkpoint["epoch"] + 1   # 从最后一个检查点获取epoch加1,确定从哪一个epoch开始继续训练
            model.load_state_dict(checkpoint["model"].state_dict()) # 将模型状态字典加载到checkpoint中
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))



    data1 = gen_patch(opt.images_dir, opt.patch_size, opt.stride)
    dataset = Dataset(data1, opt.patch_size, opt.gaussian_noise_level, opt.downsampling_factor, opt.jpeg_quality)

    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.threads,
                            pin_memory=True,
                            drop_last=True)

    eval_dataset = EvalDataset(opt.clean_valid_dir, opt.gaussian_noise_level, opt.downsampling_factor, opt.jpeg_quality)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    # 验证过程变量:最优权重参数、最优epoch、最优psnr
    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    results = {'loss': [], 'psnr': [], 'ssim': []}

    for epoch in range(opt.start_epoch, opt.num_epochs):

        # Adjust learning rate
        lr = adjust_lr(optimizer, opt.lr, epoch, opt.lr_decay_steps, opt.lr_decay_gamma)

        for param_group in optimizer.param_groups:  # 遍历优化器的参数,更新学习率
            param_group["lr"] = lr

        running_results = {'batch_sizes': 0, 'loss': 0}
        model.train()

        epoch_losses = AverageMeter()

        with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}, lr={}'.format(epoch + 1, opt.num_epochs, param_group["lr"]))
            for data in dataloader:
                inputs, labels = data

                batch_size = inputs.size(0)
                running_results['batch_sizes'] += batch_size

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                # nn.utils.clip_grad_norm_(model.parameters(), 0.01/lr)
                # nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
                optimizer.step()

                _tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                _tqdm.update(len(inputs))

        running_results['loss'] = epoch_losses.avg

        if (epoch + 1) % opt.epoch_save_num == 0:
            state = {"epoch": epoch, "model": model}
            torch.save(state,
                       os.path.join(opt.outputs_dir, '{}_epoch_{}_{}.pth'.format(opt.arch, epoch, opt.gaussian_noise_level)))

        # ************
        # 验证过程
        # ************
        model.eval()
        epoch_psnr = AverageMeter()  # 记录平均PSNR
        valing_results = {'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}  # 验证集结果字典

        for data in tqdm(eval_dataloader):
            inputs, labels = data
            batch_size = inputs.size(0)
            valing_results['batch_sizes'] += batch_size

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs)

            # output = preds.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
            # output = pil_image.fromarray(output, mode='RGB')

            SR = preds.mul(255.0).cpu().numpy().squeeze(0)
            SR = np.clip(SR, 0.0, 255.0).transpose([1, 2, 0])
            SR_y = SR.astype(np.float32)[..., 0] / 255.
            # SR_y = utils.rgb2ycbcr(SR).astype(np.float32)[..., 0] / 255.

            hr_image = labels.mul(255.0).cpu().numpy().squeeze(0)
            hr_image = np.clip(hr_image, 0.0, 255.0).transpose([1, 2, 0])
            hr_y = hr_image.astype(np.float32)[..., 0] / 255.
            # hr_y = utils.rgb2ycbcr(hr_image).astype(np.float32)[..., 0] / 255.

            # epoch_ssim = calculate_ssim(SR, hr_image)
            # epoch_ssim = calculate_ssim(SR_y * 255, hr_y * 255)
            # epoch_ssim = calculate_ssim(SR_y, hr_y)
            epoch_psnr1 = compare_psnr(SR, hr_image, data_range=SR.max() - SR.min())
            epoch_ssim = compare_ssim(SR, hr_image, channel_axis=2, data_range=SR.max() - SR.min())

            valing_results['ssims'] += epoch_ssim * batch_size  # 更新SSIM
            valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']

            # epoch_psnr.update(calc_psnr(SR, hr_image))
            # epoch_psnr.update(calc_psnr(SR_y, hr_y))
            epoch_psnr.update(epoch_psnr1)

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))
        print('eval ssim: {:.4f}'.format(valing_results['ssim']))
        valing_results['psnr'] = epoch_psnr.avg.item()  # psnr是tensor类型,只需要存里面的值

        # 得到最优的psnr
        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

        # 存各种值
        results['loss'].append(running_results['loss'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        # 保存验证集数据
        if (epoch + 1) % opt.epoch_save_num == 0 and epoch != 0:
            data_frame = pd.DataFrame(
                data={'Loss': results['loss'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(opt.start_epoch, epoch + 1))
            data_frame.to_csv(opt.outputs_dir + '_srf_' + str(opt.gaussian_noise_level) + '_' + str(opt.arch) + '_train_results.csv',
                              index_label='Epoch')

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    best_weight_path = 'best_' + str(opt.arch) + '_' + str(opt.gaussian_noise_level) + '.pth'
    torch.save(best_weights, os.path.join(opt.outputs_dir, best_weight_path))

2.4 测试

本节对应test_benchmark.py。

实现思路:

  1. 读取测试集文件夹,得到每张图像
  2. 读取训练好的模型文件
  3. 遍历每张图像,制作带噪图像作为模型输入
  4. 模型输出与原图像之间计算PSNR和SSIM
  5. 保存去噪后的图像

代码如下:

# ********************************************
# 测试BSD200,S14
# 保存去噪后的图像
# 计算PSNR和SSIM
# ********************************************

import argparse
import os
import io
import numpy as np
import PIL.Image as pil_image
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms
from memnet1 import MemNet
from memnet import MemNet_BS

import matplotlib.pyplot as plt
import glob
import random

# from skimage.measure import compare_psnr, compare_ssim # 老版本有报错
from skimage.io import imread, imsave
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='MemNet_BS', help='MemNet_MS,MemNet_BS')
    parser.add_argument('--weights_path', type=str, default='weights/best_MemNet_MS_[70].pth')
    parser.add_argument('--images_dir', type=str, default='datasets/test/S14')
    parser.add_argument('--outputs_denoising_dir', type=str, default='data/S14_MemNet_MS_denoising_70_outputs')
    parser.add_argument('--outputs_plt_dir', type=str, default='data/S14_MemNet_MS_denoising_70_plt_outputs')
    parser.add_argument('--gaussian_noise_level', type=str, default='70')
    parser.add_argument('--jpeg_quality', type=int)
    parser.add_argument('--downsampling_factor', type=int)
    opt = parser.parse_args()

    if not os.path.exists(opt.outputs_denoising_dir):
        os.makedirs(opt.outputs_denoising_dir)

    if not os.path.exists(opt.outputs_plt_dir):
        os.makedirs(opt.outputs_plt_dir)

    # MemNet_BS的30和50是用MemNet1使用Base MemNet architecture所得
    if opt.arch == 'MemNet_BS':
        model = MemNet_BS(3, 64, 6, 6)
        model = model.to(device)
    if opt.arch == 'MemNet_MS':
        model = MemNet(3, 64, 6, 6)
        model = model.to(device)

    # 不使用最优模型
    # model = torch.load(opt.weights_path)["model"]

    # 使用最优模型,二者的区别是保存的形式不同
    state_dict = model.state_dict()
    for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model = model.to(device)
    model.eval()

    image_list = glob.glob(opt.images_dir + "/*.*")

    benchmark_len = len(image_list)

    sum_psnr = 0.0
    sum_ssim = 0.0

    for i in range(benchmark_len):
        filename = os.path.basename(image_list[i]).split('.')[0]
        descriptions = ''
        print("image:", filename)
        input = pil_image.open(image_list[i]).convert('RGB')
        # input = imread(image_list[i])

        GT = input
        GT_cal = np.array(input).astype(np.float32) / 255.0

        if opt.gaussian_noise_level is not None and type(opt.gaussian_noise_level)!=list:
            opt.gaussian_noise_level = list(map(lambda x: int(x), opt.gaussian_noise_level.split(',')))

        if len(opt.gaussian_noise_level) == 1:
            sigma = opt.gaussian_noise_level[0]
        else:
            sigma = random.randint(opt.gaussian_noise_level[0], opt.gaussian_noise_level[1])

        # 加对应噪声水平的噪声
        # noise = np.random.normal(0.0, sigma, input.shape).astype(np.float32)
        noise = np.random.normal(0.0, sigma, (input.height, input.width, 3)).astype(np.float32)
        input = np.array(input).astype(np.float32) + noise

        # # 图像本来就有噪声,不加噪
        # input = np.array(input).astype(np.float32)

        descriptions += '_noise_l{}'.format(sigma)
        # pil_image.fromarray(input.clip(0.0, 255.0).astype(np.uint8)).save(os.path.join(opt.outputs_denoising_dir, '{}{}.png'.format(filename, descriptions)))
        input /= 255.0
        noisy_input = input

        if opt.jpeg_quality is not None:
            buffer = io.BytesIO()
            input.save(buffer, format='jpeg', quality=opt.jpeg_quality)
            input = pil_image.open(buffer)
            descriptions += '_jpeg_q{}'.format(opt.jpeg_quality)
            input.save(os.path.join(opt.outputs_dir, '{}{}.png'.format(filename, descriptions)))
            input = np.array(input).astype(np.float32)
            input /= 255.0

        if opt.downsampling_factor is not None:
            original_width = input.width
            original_height = input.height
            input = input.resize((input.width // opt.downsampling_factor,
                                  input.height // opt.downsampling_factor),
                                 resample=pil_image.BICUBIC)
            input = input.resize((original_width, original_height), resample=pil_image.BICUBIC)
            descriptions += '_sr_s{}'.format(opt.downsampling_factor)
            input.save(os.path.join(opt.outputs_dir, '{}{}.png'.format(filename, descriptions)))
            input = np.array(input).astype(np.float32)
            input /= 255.0

        input = transforms.ToTensor()(input).unsqueeze(0).to(device)

        with torch.no_grad():
            pred = model(input)

        output = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
        denoising_output = output / 255.0
        output = pil_image.fromarray(output, mode='RGB')
        output.save(os.path.join(opt.outputs_denoising_dir, '{}_{}_{}.png'.format(filename, descriptions, opt.arch)))

        # 计算指标skimage
        psnr = compare_psnr(GT_cal, denoising_output, data_range=GT_cal.max() - GT_cal.min())
        ssim = compare_ssim(GT_cal, denoising_output, channel_axis=2, data_range=GT_cal.max() - GT_cal.min())
        sum_psnr += psnr
        sum_ssim += ssim

        # # 自己定义指标
        # ssim = calculate_ssim(GT_cal * 255, denoising_output * 255)
        # sum_ssim += ssim
        #
        # psnr = calc_psnr(GT_cal, denoising_output)
        # sum_psnr += psnr

        # 对比图
        fig, axes = plt.subplots(1, 3)
        # 关闭坐标轴
        for ax in axes:

            ax.axis('off')

        # 在每个子图中显示对应的图像
        axes[0].imshow(GT)
        axes[0].set_title('Ground-Truth')
        axes[1].imshow(noisy_input)
        axes[1].set_title('noisy')
        axes[2].imshow(output)
        axes[2].set_title('{}'.format(opt.arch))

        # # 只有噪声图像和去噪后图像
        # axes[0].imshow(noisy_input)
        # axes[0].set_title('noisy_image')
        # axes[1].imshow(output)
        # axes[1].set_title('DnCNN')

        # 保存图像
        plt.savefig(os.path.join(opt.outputs_plt_dir, '{}_plt_x{}_{}.png'.format(filename, opt.gaussian_noise_level, opt.arch)),
                    bbox_inches='tight', dpi=600)

    print('PSNR: {:.2f}'.format(sum_psnr / benchmark_len))
    print('SSIM: {:.4f}'.format(sum_ssim / benchmark_len))

测试集“14 images”和BSD200在噪声水平为30,50,70情况下的平均PSNR/SSIM:

DatasetNoise MemNet(Paper) MemNet_BS(本文复现) MemNet_MS(本文复现)
14 images3029.22/0.844430.73/0.893028.46/0.7848
5026.91/0.777528.17/0.833623.30/0.5473
7025.43/0.726026.41/0.770922.39/0.4941
BSD2003028.04/0.805330.76/0.885328.87/0.7921
5025.86/0.720228.28/0.818723.85/0.5326
7024.53/0.660826.63/0.750623.06/0.4884

Basic MemNet architecture的表现优于Multi-supervised MemNet architecture,memnet.py中的基础结构是原作者认可的Pytorch版本结构,此结构性能最好。原论文是caffe实现的,本文是Pytorch实现的,以及其他各种综合因素,所以指标计算有差别。

去噪效果展示:

σ = 30 :
在这里插入图片描述
σ = 50 :

在这里插入图片描述
σ=70:
在这里插入图片描述
和REDNet对比(仔细看MemNet的去噪效果更好一些):
在这里插入图片描述

三、总结与思考

  1. 多监督模型在σ=50和70时的模型可能没有训练好,导致指标偏低。
  2. 网络结构用了BN层,根据超分领域的经验,不用BN层是不是能有提升。
  3. 读者可以自己训练彩色图像。

完整代码和训练好的模型权重文件下载链接

完整代码和训练好的模型权重文件下载链接 :图像去噪MemNet的Pytorch复现代码,包含计算PSNR/SSIM代码以及训练好的模型文件,可以直接使用,训练自己的数据集


至此本文结束。

如果本文对你有所帮助,请点赞收藏,创作不易,感谢您的支持!

  • 6
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
U-Net是一种深度学习模型,最初用于生物医学像分割,但它也可以应用于图像去噪任务。在PyTorch复现U-Net,你可以按照以下步骤操作: 1. **安装依赖**:首先确保已经安装了PyTorch及其相关的库,如torchvision。如果需要,可以运行`pip install torch torchvision`. 2. **网络结构搭建**:创建一个U-Net模型的核心部分,它包括编码器(逐渐降低分辨率,提取特征)和解码器(逐步增加分辨率,恢复细节)。可以参考论文《Image Segmentation through Deep Learning》的架构。 ```python import torch.nn as nn from torch.nn import Conv2d, MaxPool2d, UpSample class UNetBlock(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): super(UNetBlock, self).__init__() self.encoder = nn.Sequential( Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding), nn.ReLU(), Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding) ) self.decoder = nn.Sequential( nn.ConvTranspose2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding), nn.ReLU(), nn.Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding) ) def forward(self, x): skip_connection = x x = self.encoder(x) x = self.decoder(x) return torch.cat((x, skip_connection), dim=1) # 构建完整的U-Net模型 def create_unet(input_channels, num_classes): unet = nn.Sequential( nn.Conv2d(input_channels, 64, 3, padding=1), nn.MaxPool2d(2, 2), UNetBlock(64, 128), nn.MaxPool2d(2, 2), UNetBlock(128, 256), nn.MaxPool2d(2, 2), UNetBlock(256, 512), nn.MaxPool2d(2, 2), UNetBlock(512, 1024), nn.Upsample(scale_factor=2), UNetBlock(1024, 512), nn.Upsample(scale_factor=2), UNetBlock(512, 256), nn.Upsample(scale_factor=2), UNetBlock(256, 128), nn.Upsample(scale_factor=2), nn.Conv2d(128, num_classes, 1) ) return unet ``` 3. **训练和应用**:准备噪声像数据、对应干净像的数据集,然后定义损失函数(如MSE或SSIM)、优化器,并开始训练训练完成后,对新的噪声像进行前向传播以获得去噪后的结果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

十小大

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值