[Image Denoising] Paper Reproduction: A Beginner-Friendly PyTorch Reproduction of RIDNet! Run the Training and Test Code with Ease! A Step-by-Step Breakdown of the RIDNet Architecture! Train on Your Own Dataset by Simply Changing the Paths! A Full Walkthrough of Training, Inference, and Testing!

If this is your first visit, please start with the column introduction article:

Highlights of this post:

  • The training and test code runs out of the box, with no runtime issues guaranteed
  • A step-by-step (beginner-proof) breakdown of the RIDNet architecture, plus a full walkthrough of data processing, model training and validation, and inference testing; accessible to complete beginners, whether for research or applications, and a must-read introduction to denoising
  • Theory is paired with source code, to deepen understanding of the algorithm and clarify the training and test pipelines
  • Train on your own image dataset, grayscale or RGB, by simply changing the paths and related parameters
  • Side-by-side comparisons of images before and after denoising, and of the noise itself
  • Evaluation metrics on the test set, with added code for computing PSNR and SSIM


前言

Paper title: Real Image Denoising with Feature Attention

Paper link: Real Image Denoising with Feature Attention

Official source code: https://github.com/saeed-anwar/RIDNet

Companion paper deep-dive: [Image Denoising] Paper Deep-Dive: Real Image Denoising with Feature Attention (RIDNet)

Another RIDNet reproduction: https://github.com/eremo2002/Pytorch-RIDNet


Why not use the official code: it targets PyTorch 0.4, Python 3.6, and CUDA 9.0. This post was written in 2024, when PyTorch is past 1.1 and CUDA is generally 10 or newer. The APIs of PyTorch 1.1+ differ substantially from 0.4, so the official code throws many errors. My working environment is CUDA 12; PyTorch and Python could be rolled back to old versions, but the GPU cannot, so I abandoned reproducing with the official source.

Code in this post: starting from the CBDNet reproduction code, the model is swapped from CBDNet to RIDNet, with everything else left essentially unchanged. CBDNet reproduction article: [Image Denoising] Paper Reproduction: A Beginner-Friendly PyTorch Reproduction of CBDNet! Run the Training and Test Code with Ease! Train on Your Own Dataset by Simply Changing the Paths! Thoroughly Commented Code! A Full Walkthrough of Data Processing, Training and Validation, and Inference Testing!

1. Running the Code (Quick Start)

Project layout:

  • data: folder of single test images
  • datasets: folder containing the datasets
  • weights: where trained models are saved
  • loader.py: dataset wrapper
  • predict.py: visual denoising test on a single image
  • RIDNet.py: RIDNet model implementation
  • test_benchmark.py: computes PSNR/SSIM on the test set and saves the denoised results
  • test_noise.py: previews the noise added to images
  • train.py: trains RIDNet
  • utils.py: utility script with assorted image operations

1.1 Dataset Preparation

Same as in the CBDNet reproduction.

1.2 Replacing the CBDNet Model with RIDNet

Add RIDNet.py where the model files live.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels//reduction, 1, 1, 0)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels//reduction, in_channels, 1, 1, 0)
        self.sigmoid2 = nn.Sigmoid()

    def forward(self, x):
        gap = self.gap(x)
        x_out = self.conv1(gap)
        x_out = self.relu1(x_out)
        x_out = self.conv2(x_out)
        x_out = self.sigmoid2(x_out)
        x_out = x_out * x        
        return x_out
    

class EAM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, reduction=4):
        super(EAM, self).__init__()

        # Merge and run unit
        self.path1_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)
        self.path1_relu1 = nn.ReLU()
        self.path1_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=2, dilation=2)
        self.path1_relu2 = nn.ReLU()

        self.path2_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=3, dilation=3)
        self.path2_relu1 = nn.ReLU()
        self.path2_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=4, dilation=4)
        self.path2_relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(in_channels*2, out_channels, kernel_size, stride=1, padding=1)
        self.relu3 = nn.ReLU()

        # Residual block
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.ReLU()

        # Enhance Residual block
        self.conv6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu7 = nn.ReLU()
        self.conv8 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu8 = nn.ReLU()

        # Channel Attention
        self.ca = ChannelAttention(in_channels, out_channels, reduction=16)

    def forward(self, x):
        # Merge and run block        
        x1 = self.path1_conv1(x)
        x1 = self.path1_relu1(x1)        
        x1 = self.path1_conv2(x1)
        x1 = self.path1_relu2(x1)

        x2 = self.path2_conv1(x)
        x2 = self.path2_relu1(x2)
        x2 = self.path2_conv2(x2)
        x2 = self.path2_relu2(x2)

        x3 = torch.cat([x1, x2], dim=1)
        x3 = self.conv3(x3)
        x3 = self.relu3(x3)
        x3 = x3 + x

        # Residual block
        x4 = self.conv4(x3)
        x4 = self.relu4(x4)
        x4 = self.conv5(x4)
        x5 = x4 + x3
        x5 = self.relu5(x5)

        # Enhance Residual block
        x6 = self.conv6(x5)
        x6 = self.relu6(x6)
        x7 = self.conv7(x6)
        x7 = self.relu7(x7)
        x8 = self.conv8(x7)
        x8 = x8 + x5
        x8 = self.relu8(x8)

        # CA
        x_ca = self.ca(x8)
        
        return x_ca + x





class RIDNet(nn.Module):
    def __init__(self, in_channels, out_channels, num_features):
        super(RIDNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, num_features, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=False)

        self.eam1 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam2 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam3 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam4 = EAM(in_channels=num_features, out_channels=num_features)

        self.last_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, stride=1, padding=1, dilation=1)
        
        self.init_weights()

    def forward(self, x):        
        x1 = self.conv1(x) # feature extraction module
        x1 = self.relu1(x1)
        
        x_eam = self.eam1(x1)                
        x_eam = self.eam2(x_eam)
        x_eam = self.eam3(x_eam)
        x_eam = self.eam4(x_eam)
        
        x_lsc = x_eam + x1 # Long skip connection
        x_out = self.last_conv(x_lsc) # reconstruction module
        x_out = x_out + x # Long skip connection

        return x_out

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)            
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

In train.py and predict.py, change the model import to:

from RIDNet import RIDNet

and change where the model is instantiated to:

model = RIDNet(3, 3, 64) # for grayscale images, change the input channels to 1

An even easier route: keep the old model class name, replace its internals with the RIDNet structure, and just pass the extra arguments at the call site.

1.3 Training and Testing

Same as in the CBDNet reproduction.

Learning rate and loss curves during training:


Result on “Audrey_Hepburn” from the RNI15 dataset, compared with CBDNet:

As the zoomed-in regions show, RIDNet denoises better. Feel free to test the other data yourself.

Note: if modifying the CBDNet code proves difficult or throws errors, the fully modified code (including the trained model file) can also be downloaded via the link at the end of this post.

2. Code Walkthrough

2.1 The RIDNet Architecture

This section covers RIDNet.py.

2.1.1 Architecture Recap

Overall pipeline: noisy input image → feature extraction layer (one conv) → four EAM blocks → reconstruction layer (one conv) → denoised output image.

  • A long skip connection (LSC) from the input image to the reconstruction layer's output
  • A long skip connection (LSC) from the feature extraction layer's output to the EAM output

EAM structure: a merge-and-run unit, an RB, an ERB, and a CA block (every Conv in this module is followed by a ReLU).

  • Merge-and-run unit: one branch is two dilated conv layers with dilations 1 and 2; the other branch is two dilated conv layers with dilations 3 and 4. The concatenated branch outputs pass through another conv layer, and a local skip connection (LC) with the input gives the unit's output
  • RB (residual block): two conv layers + LC
  • ERB (enhanced residual block): three conv layers + LC
  • CA (channel attention): global average pooling (GAP) → Conv+ReLU+Conv → Sigmoid → element-wise multiply with the CA input (self-gating) → short skip connection (SSC) with the EAM input

Other details: the last conv in the ERB and the convs in the CA block are 1×1; all others are 3×3. No ReLU follows the reconstruction layer; every other conv layer is followed by a ReLU.

In short: the building blocks grow from small to large, and every one wraps its input and output in a residual connection.

2.1.2 Building the Network Step by Step

Each part is implemented from the smallest block up.

Merge-and-run unit:

  • Module definition:
# Merge and run unit
self.path1_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)
self.path1_relu1 = nn.ReLU()
self.path1_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=2, dilation=2)
self.path1_relu2 = nn.ReLU()

self.path2_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=3, dilation=3)
self.path2_relu1 = nn.ReLU()
self.path2_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=4, dilation=4)
self.path2_relu2 = nn.ReLU()

self.conv3 = nn.Conv2d(in_channels*2, out_channels, kernel_size, stride=1, padding=1)
self.relu3 = nn.ReLU()
  • Forward pass:
# Merge and run block        
x1 = self.path1_conv1(x)
x1 = self.path1_relu1(x1)        
x1 = self.path1_conv2(x1)
x1 = self.path1_relu2(x1)

x2 = self.path2_conv1(x)
x2 = self.path2_relu1(x2)
x2 = self.path2_conv2(x2)
x2 = self.path2_relu2(x2)

x3 = torch.cat([x1, x2], dim=1)
x3 = self.conv3(x3)
x3 = self.relu3(x3)
x3 = x3 + x
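A detail worth noting in the definitions above: for a 3×3 convolution with stride 1, setting padding equal to the dilation keeps the spatial size unchanged, which is what lets the two branches be concatenated and added back to x. A quick shape-only check (random weights, dummy input):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)  # dummy feature map
for d in (1, 2, 3, 4):
    # effective kernel span is d*(3-1)+1, so padding=d restores the input size
    conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=d, dilation=d)
    assert conv(x).shape == x.shape
```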

Residual block (RB):

  • Module definition:
# Residual block
self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.relu4 = nn.ReLU()
self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.relu5 = nn.ReLU()
  • Forward pass:
# Residual block
x4 = self.conv4(x3)
x4 = self.relu4(x4)
x4 = self.conv5(x4)
x5 = x4 + x3
x5 = self.relu5(x5)

Enhanced residual block (ERB):

  • Module definition:
# Enhance Residual block
self.conv6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.relu6 = nn.ReLU()
self.conv7 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.relu7 = nn.ReLU()
self.conv8 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.relu8 = nn.ReLU()
  • Forward pass:
# Enhance Residual block
x6 = self.conv6(x5)
x6 = self.relu6(x6)
x7 = self.conv7(x6)
x7 = self.relu7(x7)
x8 = self.conv8(x7)
x8 = x8 + x5
x8 = self.relu8(x8)

Channel attention (CA):

  • Module definition:
self.gap = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(in_channels, out_channels//reduction, 1, 1, 0)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels//reduction, in_channels, 1, 1, 0)
self.sigmoid2 = nn.Sigmoid()
  • Forward pass:
gap = self.gap(x)
x_out = self.conv1(gap)
x_out = self.relu1(x_out)
x_out = self.conv2(x_out)
x_out = self.sigmoid2(x_out)
x_out = x_out * x     
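Note how the gating works shape-wise: after GAP and the two 1×1 convs, the attention weights have shape (N, C, 1, 1), and the multiply broadcasts them over the full H×W map, scaling each channel by a single factor. A minimal check with stand-in weights (random values in place of the conv→sigmoid output):

```python
import torch

x = torch.rand(2, 64, 32, 32)                 # dummy feature map
w = torch.sigmoid(torch.randn(2, 64, 1, 1))   # stand-in attention weights in (0, 1)
out = w * x                                   # broadcasts over H and W
assert out.shape == x.shape
# every spatial position in a channel is scaled by the same factor
assert torch.allclose(out[0, 0], x[0, 0] * w[0, 0, 0, 0])
```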

The complete EAM module:

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels//reduction, 1, 1, 0)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels//reduction, in_channels, 1, 1, 0)
        self.sigmoid2 = nn.Sigmoid()

    def forward(self, x):
        gap = self.gap(x)
        x_out = self.conv1(gap)
        x_out = self.relu1(x_out)
        x_out = self.conv2(x_out)
        x_out = self.sigmoid2(x_out)
        x_out = x_out * x        
        return x_out
    

class EAM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, reduction=4):
        super(EAM, self).__init__()

        # Merge and run unit
        self.path1_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)
        self.path1_relu1 = nn.ReLU()
        self.path1_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=2, dilation=2)
        self.path1_relu2 = nn.ReLU()

        self.path2_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=3, dilation=3)
        self.path2_relu1 = nn.ReLU()
        self.path2_conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=4, dilation=4)
        self.path2_relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(in_channels*2, out_channels, kernel_size, stride=1, padding=1)
        self.relu3 = nn.ReLU()

        # Residual block
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.ReLU()
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.ReLU()

        # Enhance Residual block
        self.conv6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu6 = nn.ReLU()
        self.conv7 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu7 = nn.ReLU()
        self.conv8 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu8 = nn.ReLU()

        # Channel Attention
        self.ca = ChannelAttention(in_channels, out_channels, reduction=16)

    def forward(self, x):
        # Merge and run block        
        x1 = self.path1_conv1(x)
        x1 = self.path1_relu1(x1)        
        x1 = self.path1_conv2(x1)
        x1 = self.path1_relu2(x1)

        x2 = self.path2_conv1(x)
        x2 = self.path2_relu1(x2)
        x2 = self.path2_conv2(x2)
        x2 = self.path2_relu2(x2)

        x3 = torch.cat([x1, x2], dim=1)
        x3 = self.conv3(x3)
        x3 = self.relu3(x3)
        x3 = x3 + x

        # Residual block
        x4 = self.conv4(x3)
        x4 = self.relu4(x4)
        x4 = self.conv5(x4)
        x5 = x4 + x3
        x5 = self.relu5(x5)

        # Enhance Residual block
        x6 = self.conv6(x5)
        x6 = self.relu6(x6)
        x7 = self.conv7(x6)
        x7 = self.relu7(x7)
        x8 = self.conv8(x7)
        x8 = x8 + x5
        x8 = self.relu8(x8)

        # CA
        x_ca = self.ca(x8)
        
        return x_ca + x

The complete RIDNet model:

class RIDNet(nn.Module):
    def __init__(self, in_channels, out_channels, num_features):
        super(RIDNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, num_features, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=False)

        self.eam1 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam2 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam3 = EAM(in_channels=num_features, out_channels=num_features)
        self.eam4 = EAM(in_channels=num_features, out_channels=num_features)

        self.last_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, stride=1, padding=1, dilation=1)
        
        self.init_weights()

    def forward(self, x):        
        x1 = self.conv1(x) # feature extraction module
        x1 = self.relu1(x1)
        
        x_eam = self.eam1(x1)                
        x_eam = self.eam2(x_eam)
        x_eam = self.eam3(x_eam)
        x_eam = self.eam4(x_eam)
        
        x_lsc = x_eam + x1 # Long skip connection
        x_out = self.last_conv(x_lsc) # reconstruction module
        x_out = x_out + x # Long skip connection

        return x_out

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)            
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

Note: once enough papers have been reproduced, I will stop covering network architectures in this much detail. The first few reproductions aim to be as thorough as possible, so that given an architecture diagram and a written description, you can quickly build the network in PyTorch yourself.

2.2 Other Code

Essentially the same as the previous CBDNet post, with the CBDNet model swapped out for RIDNet. For details see the CBDNet reproduction: [Image Denoising] Paper Reproduction: A Beginner-Friendly PyTorch Reproduction of CBDNet! Run the Training and Test Code with Ease! Train on Your Own Dataset by Simply Changing the Paths! Thoroughly Commented Code! A Full Walkthrough of Data Processing, Training and Validation, and Inference Testing!

3. Thoughts and Additions

3.1 Thoughts

How CBDNet produced the result shown in Figure 1 of the original paper still puzzles me.

Since attention has proven effective for this model, there are plenty of directions for creative modification: plug in newer attention modules, or combine similar blocks for denoising. It does not have to be blind denoising of real images, either; results under additive white Gaussian noise should be even better.

3.2 Added PSNR and SSIM Computation Code

Same as the supplementary code for CBDNet; the test_benchmark.py file is identical.
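For reference, the core of the PSNR computation is only a few lines. Below is a minimal NumPy sketch (the psnr name and data_range default are mine, not the repo's; test_benchmark.py remains the authoritative version, and SSIM needs a proper implementation such as scikit-image's structural_similarity):

```python
import numpy as np

def psnr(clean, denoised, data_range=255.0):
    """Peak signal-to-noise ratio in dB between two images of the same shape."""
    mse = np.mean((clean.astype(np.float64) - denoised.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(data_range ** 2 / mse)

# toy check: a uniform error of 16 gray levels gives MSE = 256
a = np.zeros((8, 8), dtype=np.uint8)
b = np.full((8, 8), 16, dtype=np.uint8)
print(round(psnr(a, b), 2))  # 24.05
```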

On the same test set, RIDNet scores slightly lower:

PSNR: 38.49
SSIM: 0.9540

The reason: neither the CBDNet run nor the RIDNet run uses a validation set, so the saved model is simply the one from the last epoch. That introduces a lot of randomness, and which model scores higher is a toss-up. Adding a validation set and keeping the checkpoint with the best metric would fix this. In any case, metric values do not tell the whole story of reconstruction quality: visually, RIDNet beats CBDNet on both fine textures and high-frequency edges.
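The validation-set idea above can be sketched as follows. Here evaluate_psnr and the identity "model" are hypothetical stand-ins for illustration, not code from the repo, and images are assumed to lie in [0, 1]:

```python
import torch
import torch.nn as nn

def evaluate_psnr(model, val_pairs):
    """Mean PSNR over (noisy, clean) pairs; hypothetical helper, images in [0, 1]."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for noisy, clean in val_pairs:
            mse = torch.mean((model(noisy) - clean) ** 2).clamp_min(1e-12)
            total += 10.0 * torch.log10(1.0 / mse).item()
    return total / len(val_pairs)

# toy demo: identity "model" and one perfect pair -> PSNR hits the 1e-12 floor (120 dB)
model = nn.Identity()
clean = torch.rand(1, 3, 8, 8)
best_psnr = -float("inf")
val_psnr = evaluate_psnr(model, [(clean, clean)])
if val_psnr > best_psnr:
    best_psnr = val_psnr
    # torch.save(model.state_dict(), "weights/best.pth")  # keep only the best epoch
```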

Code link: PyTorch reproduction code for RIDNet image denoising, including PSNR/SSIM computation code and a trained model file, directly usable for real image denoising


That concludes this post.

If it helped you, please like and bookmark it. Writing these takes real effort; thank you for your support!
