【图像去噪】论文复现:支持任意大小的图像输入!四十多行实现Pytorch极简版本的IRCNN,各种参数和测试集平均PSNR结果与论文一致!

请先看【专栏介绍文章】:【图像去噪(Image Denoising)】关于【图像去噪】专栏的相关说明,包含适配人群、专栏简介、专栏亮点、阅读方法、定价理由、品质承诺、关于更新、去噪概述、文章目录、资料汇总、问题汇总(更新中)

完整代码和训练好的模型权重文件下载链接见本文底部,订阅专栏免费获取!

本文亮点:

  • 跑通训练和测试代码,轻松运行,按步骤执行保证无任何运行问题
  • Pytorch实现IRCNN,参数设置、测试结果与原论文一致,简单易懂
  • 更换路径和参数即可训练自己的图像数据集支持灰度图和RGB图
  • 包含训练好的模型文件(共3个,灰度图下对应3个噪声level=15,25,50),可直接推理运行得到去噪后的图像结果以及评价指标PSNR/SSIM
  • 数据处理、模型训练和验证、推理测试全流程讲解,无论是科研还是应用,新手小白都能看懂,学习阅读毫无压力,去噪入门必看


前言

论文题目:Learning Deep CNN Denoiser Prior for Image Restoration —— 学习深度CNN降噪先验用于图像重建

论文地址:Learning Deep CNN Denoiser Prior for Image Restoration

论文源码:https://github.com/cszn/ircnn

对应的论文精读:【图像去噪】论文精读:Learning Deep CNN Denoiser Prior for Image Restoration(IRCNN)

完整源码是matlab版本的,本文按照前面文章的风格,复现一个简单版本的Pytorch代码。

一、跑通代码 (Quick Start)

项目文件说明:
在这里插入图片描述

  • data:去噪后图像结果保存位置
  • datasets:数据集所在文件夹
  • Plt:训练过程指标曲线可视化位置(Loss、PSNR、SSIM与Epoch关系曲线)
  • weights:训练模型保存位置
  • dataset.py:封装数据集
  • draw_evaluation.py:绘制指标曲线
  • model.py:IRCNN模型实现
  • test.py:计算测试集指标;保存去噪后图像
  • train.py:训练IRCNN

1.1 数据集准备

按如下顺序执行:

  • 训练集:BSD400+ImageNet验证集中的400张+Waterloo Exploration Database4744张,共5544张图像,转为灰度图后,放在datasets/Train_gray
  • 测试集/验证集:BSD68,路径为datasets/BSD68
  • 制作.h5数据集:设置prepare.py参数
    parser.add_argument('--arch', type=str, default='IRCNN')
    parser.add_argument('--images-dir', type=str, default='datasets/Train_gray')	# 训练集路径
    parser.add_argument('--is_gray', type=str, default=True)	#	是否是灰度图
    parser.add_argument('--output-path', type=str, default='datasets/Train_25.h5')	# 制作的.h5
    parser.add_argument('--patch-size', type=int, default=35)	# 块大小
    parser.add_argument('--stride', type=int, default=35)	# 论文没说步长就是不重叠切块
    parser.add_argument('--sigma', type=int, default=25) # 噪声方差,可选为15,25,50
    parser.add_argument('--eval', default=False, action='store_true')	# 制作验证集设置为True
  • 执行prepare.py,依次制作sigma=15,25,50的数据集。执行完毕后,datasets文件中会出现制作好的训练集和验证集h5文件。训练集大小约28G,验证集约为80M。

注:如果训练自己的数据集,放在对应路径下即可。

1.2 训练

设置train.py参数:

    parser.add_argument('--arch', type=str, default='IRCNN', help='IRCNN')
    parser.add_argument('--images_dir', type=str, default='datasets/Train_50.h5') # 训练集路径
    parser.add_argument('--is_gray', type=str, default=True)    # 训练集是否是灰度图
    parser.add_argument('--clean_valid_dir', type=str, default='datasets/Val_50.h5') # 验证集路径
    parser.add_argument('--outputs_dir', type=str, default='weights') # 保存模型文件夹
    parser.add_argument('--gaussian_noise_level', type=str, default='50') # 15,25,50
    parser.add_argument('--downsampling_factor', type=str, default=None)
    parser.add_argument('--jpeg_quality', type=int, default=None)
    parser.add_argument('--patch_size', type=int, default=35) # 图像块大小
    parser.add_argument('--batch_size', type=int, default=256)   # bs
    parser.add_argument('--num_epochs', type=int, default=100)   # 总epoch
    parser.add_argument('--start_epoch', type=int, default=0)  # 从第几轮开始继续训练
    parser.add_argument("--resume", default='', type=str)  # 从哪个权重模型继续训练
    parser.add_argument('--lr', type=float, default=1e-3)   # 学习率
    parser.add_argument('--lr-decay-steps', type=int, default=50)  # 多少轮后开始下降
    parser.add_argument('--lr-decay-gamma', type=float, default=0.1)    # 下降一半
    parser.add_argument('--threads', type=int, default=8)   # num_workers
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--epoch_save_num', type=int, default=1)  # 每多少轮保存指标

训练参数与论文基本一致,学习率初始为1e-3,50个epoch后降为1e-4,共训练100个epochs。实测下来这样比较平滑:

在这里插入图片描述

1.3 测试

设置test.py参数:

   parser.add_argument('--weights_path', type=str, default='weights/best_IRCNN_[50].pth')
    parser.add_argument('--images_dir', type=str, default='datasets/BSD68')
    parser.add_argument('--is_gray', type=str, default=True)
    parser.add_argument('--outputs_denoising_dir', type=str, default='data/BSD68_denoising_50_IRCNN')	# 去噪后图像
    parser.add_argument('--outputs_plt_dir', type=str, default='data/BSD68_denoising_50_plt_IRCNN')	 # 对比图
    parser.add_argument('--gaussian_noise_level', type=str, default='50')

执行test.py,data文件夹下会保存结果图像,控制台会输出PSNR/SSIM。

结果展示:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
BSD68上的平均PSNR:

Datasets σ IRCNN(paper) IRCNN(ours)
BSD681531.6331.61
2529.1529.13
5026.1926.10

二、代码解析

2.1 数据预处理

2.1.1 制作h5数据集

本节对应prepare.py

  • 功能:制作h5格式的训练集和验证集
  • 具体实现:训练集按大小和步长裁剪,将“干净”图像和加噪图像成对存储到h5文件中,并应用数据增强;验证集不裁剪不增强,其余一致。
  • 制作成.h5的好处:提升图像读取的效率,图像恢复任务一般要将图像数据裁剪成若干小块,如果直接读取图像效率会低,直接封装在一个h5文件中,模型训练前只读取该h5文件即可。它是图像恢复领域比较通用的做法,超分专栏中我们也经常使用。

关键代码展示:

def train(args):
    h5_file = h5py.File(args.output_path, 'w')

    noisy_patches = []
    clean_patches = []

    clean_list = glob.glob(args.images_dir + "/*.*")

    for i in range(len(clean_list)):
        filename = os.path.basename(clean_list[i]).split('.')[0]
        print("image:", filename)

        if args.is_gray:
            clean_image = pil_image.open(clean_list[i]).convert('L')
            noisy_image = clean_image.copy()
            gaussian_noise = np.zeros((clean_image.height, clean_image.width), dtype=np.float32)
            gaussian_noise += np.random.normal(0.0, args.sigma, (clean_image.height, clean_image.width)).astype(
                np.float32)
        else:
            clean_image = pil_image.open(clean_list[i]).convert('RGB')
            noisy_image = clean_image.copy()
            gaussian_noise = np.zeros((clean_image.height, clean_image.width, 3), dtype=np.float32)
            gaussian_noise += np.random.normal(0.0, args.sigma, (clean_image.height, clean_image.width, 3)).astype(
                np.float32)


        clean_image = np.array(clean_image).astype(np.float32)
        noisy_image = np.array(noisy_image).astype(np.float32)
        noisy_image += gaussian_noise

        for i in range(0, clean_image.shape[0] - args.patch_size + 1, args.stride):
            for j in range(0, clean_image.shape[1] - args.patch_size + 1, args.stride):
                clean_patch = clean_image[i:i + args.patch_size, j:j + args.patch_size]
                noisy_patch = noisy_image[i:i + args.patch_size, j:j + args.patch_size]
                clean_patches.append(clean_patch)
                noisy_patches.append(noisy_patch)
                for m in range(0, 1):
                    clean_patch_aug, noisy_patch_aug = data_aug(clean_patch, noisy_patch, mode=np.random.randint(0, 8))
                    clean_patches.append(clean_patch_aug)
                    noisy_patches.append(noisy_patch_aug)

    noisy_patches = np.array(noisy_patches)
    clean_patches = np.array(clean_patches)

    h5_file.create_dataset('lr', data=noisy_patches)
    h5_file.create_dataset('hr', data=clean_patches)

    h5_file.close()

完整代码请看文末提供的完整项目。

2.1.2 封装数据集

本节对应dataset.py。

  • 功能:读取上一小节制作好的h5数据集,然后将维度更改为训练前需要的维度,并归一化
  • 由于DataLoader可以自动将np格式的图像转为Tensor(chw→bchw),所以在送入DataLoader前,灰度图维度变化为hw→1hw,RGB图维度变化为hwc→chw

代码如下:

class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return 256 * 4000
            # return len(f['lr'])

由于本例是灰度图,所以只需要扩展第一个维度。如果读取的是RGB图,那么就修改为:

np.transpose(f['lr'][idx] / 255., (2, 0, 1)), np.transpose(f['hr'][idx] / 255., (2, 0, 1))

2.2 网络结构

本节对应model.py。
在这里插入图片描述
网络结构比较简单,共7层,第一层Conv+ReLU,中间五层Conv+BN+ReLU,最后一层只有Conv。Conv是膨胀卷积,因子为1234321。观察到有通用的"Conv+BN+ReLU"结构,很自然的想到将其单独封装成一个模块。

IRCNN的实现如下:

import torch
import torch.nn as nn

class Conv_BN_Relu(nn.Module):
    def __init__(self, in_channels, out_channels, padding, dilation):
        super(Conv_BN_Relu, self).__init__()
        self.module = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        output = self.module(x)
        return output


class IRCNN(nn.Module):
    def __init__(self, in_channels=1, num_features=64, out_channels=1):
        super(IRCNN, self).__init__()
        L =[]
        L.append(nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
        L.append(nn.ReLU(inplace=True))

        L.append(Conv_BN_Relu(in_channels=num_features, out_channels=num_features, padding=2, dilation=2))
        L.append(Conv_BN_Relu(in_channels=num_features, out_channels=num_features, padding=3, dilation=3))
        L.append(Conv_BN_Relu(in_channels=num_features, out_channels=num_features, padding=4, dilation=4))
        L.append(Conv_BN_Relu(in_channels=num_features, out_channels=num_features, padding=3, dilation=3))
        L.append(Conv_BN_Relu(in_channels=num_features, out_channels=num_features, padding=2, dilation=2))

        L.append(nn.Conv2d(in_channels=num_features, out_channels=out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
        self.model = nn.Sequential(*L)

        self._initial_weights()

    def _initial_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        n = self.model(x)
        return x-n

权重初始化:Conv使用kaiming初始化,BN的gamma缩放为1,beta偏置为0;
和DnCNN一样,网络学习的是残差。

2.3 训练

本节对应train.py。

论文中相关训练参数:

  • 训练集:BSD400+ImageNet验证集中的400张+Waterloo Exploration Database (WED)4744
  • 块大小:35×35
  • 图像块数量:256 * 4000
  • 噪声水平:σ=15,25,50
  • 优化器:Adam
  • batch_size:256
  • lr:1e-3下降到1e-4
  • 数据增强:旋转和翻转

训练参数与论文基本一致,共训练100个epoch,50个epoch时学习率下降。

训练包含这些部分:参数设置、读取数据集、实例化模型、接续训练设置、损失函数、优化器、验证集指标计算、tqdm展示训练进度等

训练代码略,请查看专栏内前面文章的训练代码,基本都一致。

如果需要贴出来请评论告知。

训练过程验证集表现:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在第50个epoch降低学习率还是挺有道理。

调参小启示:模型较复杂,数据量别太小,pytorch训练去噪模型,学习率1e-4,基本上错不了。

2.4 测试

本节对应test.py。

  • 功能:计算测试集平均PSNR/SSIM;保存去噪后结果
  • 实现:遍历图像,加噪后作为模型输入,保存模型输出,并与加噪前干净图像之间计算PSNR和SSIM

测试代码略,请查看专栏内前面文章的测试代码,基本都一致。

结果展示:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
BSD68上的平均PSNR/SSIM:

Datasets σ IRCNN(paper) IRCNN(ours)
BSD681531.6331.61/0.8893
2529.15/0.825929.13
5026.1926.10/0.7162

三、总结与思考

由于IRCNN的模型比较简单,实现并不费力。网络是对称结构,每一层stride均为1,膨胀多少就padding多少。所以,IRCNN模型适用于任意图像大小,无论奇偶都不会报错。

完整代码和训练好的模型权重文件下载链接

图像去噪IRCNN的Pytorch极简复现代码,包含计算PSNR/SSIM以及训练好的模型文件,可以直接使用,训练自己的数据集


至此本文结束。

如果本文对你有所帮助,请点赞收藏,创作不易,感谢您的支持!

  • 19
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
U-Net是一种深度学习模型,最初用于生物医学图像分割,但它也可以应用于图像去噪任务。在PyTorch复现U-Net,你可以按照以下步骤操作: 1. **安装依赖**:首先确保已经安装了PyTorch及其相关的库,如torchvision。如果需要,可以运`pip install torch torchvision`. 2. **网络结构搭建**:创建一个U-Net模型的核心部分,它包括编码器(逐渐降低分辨率,提取特征)和解码器(逐步增加分辨率,恢复细节)。可以参考论文《Image Segmentation through Deep Learning》中的架构。 ```python import torch.nn as nn from torch.nn import Conv2d, MaxPool2d, UpSample class UNetBlock(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): super(UNetBlock, self).__init__() self.encoder = nn.Sequential( Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding), nn.ReLU(), Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding) ) self.decoder = nn.Sequential( nn.ConvTranspose2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding), nn.ReLU(), nn.Conv2d(out_channels, out_channels, kernel_size, stride=stride, padding=padding) ) def forward(self, x): skip_connection = x x = self.encoder(x) x = self.decoder(x) return torch.cat((x, skip_connection), dim=1) # 构建完整的U-Net模型 def create_unet(input_channels, num_classes): unet = nn.Sequential( nn.Conv2d(input_channels, 64, 3, padding=1), nn.MaxPool2d(2, 2), UNetBlock(64, 128), nn.MaxPool2d(2, 2), UNetBlock(128, 256), nn.MaxPool2d(2, 2), UNetBlock(256, 512), nn.MaxPool2d(2, 2), UNetBlock(512, 1024), nn.Upsample(scale_factor=2), UNetBlock(1024, 512), nn.Upsample(scale_factor=2), UNetBlock(512, 256), nn.Upsample(scale_factor=2), UNetBlock(256, 128), nn.Upsample(scale_factor=2), nn.Conv2d(128, num_classes, 1) ) return unet ``` 3. **训练和应用**:准备噪声图像数据、对应干净图像的数据集,然后定义损失函数(如MSE或SSIM)、优化器,并开始训练。训练完成后,对新的噪声图像前向传播以获得去噪后的结果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

十小大

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值