MIMO-UNet复现,DeepRFT复现及总结

本文总结了作者复现MIMO-UNet和其变体DeepRFT的过程,比较了两者在训练和验证集上的性能,发现DeepRFT在训练上有提升,但在验证集上的效果与MIMO-UNet相似,可能因超参数和训练轮数差异导致。
摘要由CSDN通过智能技术生成

最近复现了去模糊网络MIMO-UNet及变体DeepRFT,并以此文做一个总结:

复现MIMO-UNet部分:

1.通过上一篇博文,我们已经知道了MIMO-UNet网络的大致组成结构
2.通过源网络的main.py文件,我们可以知道该网络间隔100轮保存一次日志,训练的轮数为3000轮,这里为了节约算力,源网络以及后续的改进网络训练的轮数均设置为1000轮,batch_size为4,以在同一个维度上进行对比
复现结果的日志如下所示:
训练损失和PSNR如下所示:(训练时间1.061day)
在这里插入图片描述
在这里插入图片描述在这里插入图片描述
图中淡颜色的线为源网络中,每训练100轮,对训练结果在验证集上进行验证的结果,是一个折线图,深颜色的线是对折线图光滑0.6的结果smooth=0.6,可以看到FFT损失和Pixel损失一直下降并且达到了收敛,PSNR指标逐渐上升达到30.69并且还有继续上升的趋势(受限于算力)
**结论:**去模糊指标应该和作者描述的结果相差不太大,论文中MIMO-UNet的PSNR是31.73,这里的轮数没有那么多

复现DeepRFT部分:

DeepRFT的网络结构图如下所示
在这里插入图片描述
对比MIMO-UNet的网络结构图,不同之处在于换个残差模块,把原来的残差模块更改为带有傅里叶变换的Res FFT-Conv Block,这个也是这篇文章的创新点:更换一个模块
1.和MIMO-UNet的网络结构相同,由此我们也初步了解了DeepRFT的网络结构
2.通过源网络train.py文件的解析参数部分,可以知道该源网络在训练过程中设置的一些参数,训练3000轮,每20轮验证一次,同样这里为了节约算力,仅仅训练了670轮(太慢了),batch_size为4
复现结果和日志如下所示:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
图中淡颜色的线为源网络中,每训练20轮,对训练结果在验证集上进行验证的结果,是一个折线图,深颜色的线是对折线图光滑0.6的结果smooth=0.6,可以看到FFT损失和Pixel损失有一些震荡,PSNR训练指标逐渐上升达到32.37,验证指标30.67相差有点多(验证指标和MIMO-UNet很接近),训练的时候表现良好,验证就很小,当然这里的训练超参数和原文中是有一些不同的
**结论:**复现的过程应该是出现了过拟合,其主要原因可能是超参数和作者设置的不一样,算力达不到,但是不知道为啥验证部分结果和MIMO-UNet那么接近

根据DeepRFT提出的模块,在MIMO-UNet上进行修改:

进行这一步的主要原因是:假如当我有这个想法之后,我在别人的网络上更换自己的模块,会碰到哪些问题,更重要的是学习一个调参的过程,事实上,即使你有了很好的想法并成功,实施效果也不一定好,这是个大概率事件,整个过程也确实很大程度上依赖于经验,深度学习也确实有很多不可解释的现象。
1.整个过程就是把MIMO-UNet上的残差模块,更换为DeepRFT中提出的带有傅里叶的残差模块
更换过程如下:
更换的残差模块在layers.py中,如下部分是原来的残差模块

#定义残差模块,这里指的是定义一个单独的残差模块,后面还有把8个残差模块封装在一起的部分
class ResBlock(nn.Module):
    """
    ResBlock的构造函数。

    该构造函数初始化了一个残差块,包含两个基本卷积层。

    参数:
    - in_channel: 输入通道数
    - out_channel: 输出通道数

    返回值:
    - 无
    """
    def __init__(self, in_channel, out_channel):
        super(ResBlock, self).__init__()
        self.main = nn.Sequential(
            # 第一个基本卷积层,用于特征提取
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            # 第二个基本卷积层,不使用ReLU激活函数,为特征映射
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x):
        return self.main(x) + x

把上面的残差块,更换为带有傅里叶变换的残差块,在layers.py文件中,增加如下部分代码(直接从DeepRFT复制粘贴过来,这里使用的是ResBlock_fft_bench,把模块中输入和输出部分的变量和MIMO-UNet保持一致如下)

class ResBlock_fft_bench(nn.Module):
    """
    使用FFT实现的残差块,用于benchmark测试。

    参数:
    - in_channle: 输入通道数
    - out_channle: 输出通道数
    - norm: 正则化方式,默认为'backward',可选'ortho'
    """
    def __init__(self, in_channle, out_channle, norm='backward'): # 'ortho'
        super(ResBlock_fft_bench, self).__init__()
        # 定义基于传统卷积的主路径
        self.main = nn.Sequential(
            BasicConv(in_channle, out_channle, kernel_size=3, stride=1, relu=True),
            BasicConv(out_channle, out_channle, kernel_size=3, stride=1, relu=False)
        )
        # 定义基于FFT的路径
        self.main_fft = nn.Sequential(
            BasicConv(out_channle*2, out_channle*2, kernel_size=1, stride=1, relu=True),
            BasicConv(out_channle*2, out_channle*2, kernel_size=1, stride=1, relu=False)
        )
        self.dim = out_channle
        self.norm = norm
    def forward(self, x):
        """
        前向传播函数。

        参数:
        - x: 输入特征图

        返回:
        - 加速的传统卷积路径和FFT路径的和
        """
        _, _, H, W = x.shape
        dim = 1
        # 使用FFT对输入进行处理
        y = torch.fft.rfft2(x, norm=self.norm)
        y_imag = y.imag
        y_real = y.real
        # 合并实部和虚部以供FFT处理使用
        y_f = torch.cat([y_real, y_imag], dim=dim)
        # 在FFT路径上应用卷积
        y = self.main_fft(y_f)
        # 分割实部和虚部,为逆FFT做准备
        y_real, y_imag = torch.chunk(y, 2, dim=dim)
        # 重建复数信号
        y = torch.complex(y_real, y_imag)
        # 应用逆FFT恢复到时域
        y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
        # 输出为传统卷积路径、输入和FFT路径的和
        return self.main(x) + x + y

把原来搭建的网络MIMOUNet中使用ResBlock类的部分更换为ResBlock_fft_bench,直接使用F7查找项目中使用这部分的代码,实现更换即可!
复现结果和日志如下所示:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
可以看到FFT损失和Pixel损失一直下降并且达到了收敛,PSNR指标逐渐上升达到30.6,有一些浮动
**结论:**这里的结果和MIMO-UNet以及DeepRFT的验证部分都比较接近,但是确实是对PSNR曲线有一些影响,影响并没有那么大…,,,可能的原因是并没有训练那么多轮次?超参数太重要,没有和源代码保持一致??

**最后的总结:**目前的结论是使用带有FFT的残差模块的网络DeepFRT在训练数据上,PSNR会有明显的提高,但是在验证集和MIMO-UNet和更换为残差模块的MIMO-UNet的区别不大,,,至于这个原因也许是超参数不一样吧,,,,

  • 22
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
MIMO-UNET是一种用于图像去模糊的算法,它结合了MIMO(Multiple-Input Multiple-Output)和UNET两种技术。下面是对MIMO-UNET去模糊算法的介绍: MIMO-UNET算法是基于深度学习的图像去模糊方法,它通过使用多个输入和多个输出来提高去模糊的效果。传统的UNET算法只使用单个输入和单个输出,而MIMO-UNET则引入了多个输入和多个输出,以更好地捕捉图像中的细节和纹理信息。 MIMO-UNET算法的核心思想是将图像的模糊处理问题转化为一个端到端的深度学习任务。它使用编码器-解码器结构,其中编码器负责提取图像的特征,解码器则负责将特征映射回原始图像空间。通过多个输入和多个输出,MIMO-UNET可以同时处理多个模糊程度的图像,并生成对应的清晰图像。 MIMO-UNET算法的训练过程包括两个阶段:训练编码器-解码器网络和训练多输入多输出网络。在第一个阶段,使用已知的清晰图像和对应的模糊图像对编码器-解码器网络进行训练,以学习图像的特征表示和重建能力。在第二个阶段,使用多个模糊程度的图像对多输入多输出网络进行训练,以学习不同程度模糊图像的去模糊映射。 MIMO-UNET算法在图像去模糊任务中具有较好的性能,它可以有效地恢复图像的细节和纹理信息。同时,MIMO-UNET还可以处理多个模糊程度的图像,适用于不同场景下的图像去模糊需求。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值