MIMO-UNet的网络结构分析

MIMO-UNet:一种2021年发表在ICCV上的去模糊网络
网络整体结构如下
在这里插入图片描述
组成模块如下
在这里插入图片描述
根据整体的网络框架,网络主要有以下部分组成:

1. EBlock 和 DBlock
这两个类分别代表编码器(EBlock)和解码器(DBlock)中的基本模块,它们内部均使用一系列的 ResBlock(残差块)。
输入:对于每个 ResBlock,输入是一组多通道的特征图(张量),通常其形状为 (batch_size, channels, height, width)。
输出:每个 ResBlock 的输出同样是多通道的特征图,形状与输入相同。EBlock 或 DBlock 作为一个整体,其输出是由内部多个 ResBlock 依次处理后的最终特征图。
连接方式:在 EBlock 或 DBlock 内部,多个 ResBlock 通过 nn.Sequential 连接,形成一个顺序执行的结构。特征图从第一个 ResBlock 输入,逐个经过所有 ResBlock 处理后得到最终输出。
2. AFF
AFF (Affine Flow Field) 模块用于实现特征融合。
输入:三个张量 x1, x2, x4,它们代表不同层次或尺度的特征图。
输出:经过两个基本卷积层处理后的融合特征图。
连接方式:首先将输入的三个张量在通道维度(dim=1)上进行连接,然后通过包含两个基本卷积层的 nn.Sequential 结构进行处理。
3. SCM
SCM 模块是一个特定的卷积模块,包含多个基本卷积层。
输入:单个特征图。
输出:经过该模块处理后的特征图。
连接方式:直接将输入特征图送入由多个基本卷积层组成的 nn.Sequential 结构进行处理。
4. FAM
FAM (Feature Attention Module) 模块用于实现特征注意力机制。
输入:两个特征图 x1 和 x2。
输出:经过特征注意力机制处理后的输出特征图。
连接方式:首先将输入特征图进行元素乘积融合,然后将融合结果与其中一个输入特征图通过一个基本卷积层(merge)结合并相加,得到最终输出。
5. MIMOUNet 和 MIMOUNetPlus
这两个类代表整个 MIMO-UNet 和 MIMO-UNetPlus 模型架构,它们内部包含了多个上述组件以及其他辅助模块。
输入:一个三维图像张量 x,形状为 (batch_size, channels, height, width)。
输出:一个包含多个中间输出的列表,每个输出是不同尺度下的处理结果(通常是图像)。
连接方式:
编码器(Encoder):将输入图像经过一系列 EBlock 进行下采样和特征提取,得到不同层次的特征图。
特征提取模块(feat_extract):包含多个基本卷积层,用于进一步处理编码器的输出或解码器的输入。
解码器(Decoder):通过一系列 DBlock 进行上采样,同时融合来自编码器的特征图,恢复图像细节。
中间卷积层(Convs):用于调整特征图的通道数,以便与解码器输出进行拼接。
输出卷积层(ConvsOut):生成最终的输出图像。
自适应特征融合模块(AFFs):融合不同尺度的特征图。
特征注意力模块(FAM1, FAM2, SCM1, SCM2):在特定阶段应用特征注意力机制或特殊卷积处理。
整体流程如下:
输入图像首先通过编码器进行下采样和特征提取。
编码器的输出通过特征提取模块进行进一步处理。
解码器对编码器输出进行上采样,并在适当位置融合来自编码器的特征图。
中间卷积层调整特征图通道数以适应解码器输出。
输出卷积层生成最终的图像输出。
自适应特征融合模块和特征注意力模块在特定步骤中插入,以增强特征交互和信息传递。
基础层部分代码

import torch
import torch.nn as nn


class BasicConv(nn.Module): #定义基础卷积部分,定义紫色模块部分,对于编程来说定义好每一个模块,再进行调用
    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        """
        初始化BasicConv类的实例。
        :param in_channel: 输入通道数
        :param out_channel: 输出通道数
        :param kernel_size: 卷积核大小
        :param stride: 卷积步长
        :param bias: 是否使用偏置,默认为True
        :param norm: 是否使用归一化,默认为False
        :param relu: 是否使用ReLU激活函数,默认为True
        :param transpose: 是否使用转置卷积,默认为False
        """

        super(BasicConv, self).__init__()
        # 如果同时设置了偏置和归一化,则关闭偏置
        if bias and norm:
            bias = False

        # 计算padding值以保持输入输出尺寸一致,输入和输出大小不改变
        padding = kernel_size // 2
        layers = list()
        # 如果是转置卷积,则调整padding值并添加到layers中
        if transpose:
            padding = kernel_size // 2 - 1
            layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        else:
            # 否则,添加普通的卷积层到layers中
            layers.append(
                nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        # 如果设置了归一化,则添加BatchNorm2d层
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        # 如果设置了ReLU激活函数,则添加ReLU层
        if relu:
            layers.append(nn.ReLU(inplace=True))
        # 将所有层封装成一个Sequential容器,可以直接调用main
        self.main = nn.Sequential(*layers)


    def forward(self, x):
        """
        实现前向传播过程。

        参数:
        self -- 对象的自引用。
        x -- 输入数据。

        返回值:
        返回经过模型主要结构处理后的结果。
        """
        return self.main(x)  # 调用main模块进行主要的前向传播处理

#定义残差模块,这里指的是定义一个单独的残差模块,后面还有把8个残差模块封装在一起的部分
class ResBlock(nn.Module):
    """
    ResBlock的构造函数。

    该构造函数初始化了一个残差块,包含两个基本卷积层。

    参数:
    - in_channel: 输入通道数
    - out_channel: 输出通道数

    返回值:
    - 无
    """
    def __init__(self, in_channel, out_channel):
        super(ResBlock, self).__init__()
        self.main = nn.Sequential(
            # 第一个基本卷积层,用于特征提取
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            # 第二个基本卷积层,不使用ReLU激活函数,为特征映射
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x):
        return self.main(x) + x

网络实现部分代码:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import *

# EBlock类:构造包含多个ResBlock层的序列模型
class EBlock(nn.Module):
    """
    EBlock类的构造函数。

    用于创建一个包含多个ResBlock层的序列模型。

    参数:
    - out_channel: int, 输出通道数。
    - num_res: int, ResBlock的数量,默认为8。

    返回:
    - 无
    """

    def __init__(self, out_channel, num_res=8):
        super(EBlock, self).__init__()

        # 通过列表推导式创建num_res个ResBlock层,并将它们封装到nn.Sequential中
        layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """
        实现前向传播过程。

        参数:
        - x输入数据,可以是单个数据样本或数据批量。

        返回值:
        - 经过所有层处理后的输出结果。
        """
        return self.layers(x)  # 通过self.layers中的层对输入x进行处理


class DBlock(nn.Module):
    """
    DBlock类:定义了一个包含多个ResBlock层的深度块。

    参数:
    - channel: int, 每个ResBlock层的通道数。
    - num_res: int, DBlock中ResBlock层的数量,默认为8。

    返回值:
    - 无
    """

    def __init__(self, channel, num_res=8):
        super(DBlock, self).__init__()

        # 通过循环构造num_res个ResBlock层,并将它们串联起来
        layers = [ResBlock(channel, channel) for _ in range(num_res)]
        self.layers = nn.Sequential(*layers)
        # self.layers即为包含num_res个ResBlock层的序列,可用于模型的前向传播。

    def forward(self, x):
        """
        实现前向传播过程。

        参数:
        - x输入数据,可以是单个数据样本或数据批量。

        返回值:
        - 经过所有层处理后的输出结果。
        """
        return self.layers(x)  # 通过self.layers中的层对输入x进行处理


# AFF类:初始化AFF(Affine Flow Field)模块
class AFF(nn.Module):
    """
    初始化AFF(Affine Flow Field)模块。

    该模块用于创建一个包含两个基本卷积层的序列,第一个卷积层用于转换输入通道到输出通道,第二个卷积层保持输出通道不变。

    参数:
    - in_channel(int): 输入通道数。
    - out_channel(int): 输出通道数。

    属性:
    - conv(nn.Sequential): 包含两个基本卷积层的序列。
    """
    def __init__(self, in_channel, out_channel):
        super(AFF, self).__init__()
        # 创建一个包含两个基本卷积层的序列:
        # 第一个卷积层用于从in_channel转换到out_channel,
        # 第二个卷积层保持out_channel不变。
        self.conv = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x1, x2, x4):
        """
        实现前向传播过程。

        参数:
        - x1, x2, x4: 输入的三个张量,它们将在维度1上连接。

        返回:
        - 经过self.conv卷积操作处理后的张量。
        """
        # 在维度1上连接输入的三个张量
        x = torch.cat([x1, x2, x4], dim=1)
        # 通过卷积层处理连接后的张量
        return self.conv(x)

class SCM(nn.Module):
    """
    SCM类:实现了一个特定的卷积模块。

    参数:
    - out_plane: 输出通道数。

    方法:
    - forward: 前向传播方法。
    """
    def __init__(self, out_plane):
        """
        初始化SCM模块。

        参数:
        - out_plane: 输出通道数。
        """
        super(SCM, self).__init__()
        # 定义主要的卷积序列,包括四个基本卷积层
        self.main = nn.Sequential(
            BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane-3, kernel_size=1, stride=1, relu=True)
        )

        # 定义一个额外的卷积层,用于最后的特征融合
        self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False)

    def forward(self, x):
        """
        定义前向传播路径。

        参数:
        - x: 输入特征数据。

        返回:
        - 经过SCM模块处理后的特征数据。
        """
        # 将输入特征与主卷积序列的输出在通道维度上连接
        x = torch.cat([x, self.main(x)], dim=1)
        # 使用额外的卷积层处理连接后的特征数据
        return self.conv(x)

# FAM类:实现特征注意力模块
class FAM(nn.Module):
    """
    FAM(Feature Attention Module)类,继承自nn.Module,用于实现特征注意力模块。

    参数:
    - channel: 输入通道数

    方法:
    - forward: 前向传播方法。
    """
    def __init__(self, channel):
        """
        初始化FAM模块。

        参数:
        - channel: 输入特征的通道数。
        """
        super(FAM, self).__init__()
        self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False)

    def forward(self, x1, x2):
        """
        定义前向传播操作。

        参数:
        - x1: 输入特征1
        - x2: 输入特征2

        返回:
        - out: 经过特征注意力机制处理后的输出特征。
        """
        x = x1 * x2  # 通过元素乘积融合两个输入特征
        out = x1 + self.merge(x)  # 将融合后的特征与原始特征x1通过BasicConv模块结合,并相加得到最终输出
        return out


# MIMOUNet类:定义MIMOUNet模型
class MIMOUNet(nn.Module):
    """
    MIMOUNet模型的类定义。

    该模型用于图像处理任务,具体结构包括编码器、特征提取模块、解码器等部分。
    参数:
    - num_res: int, 默认为8, 表示残差块的数量。
    """
    def __init__(self, num_res=8):
        super(MIMOUNet, self).__init__()

        base_channel = 32

        # 初始化编码器模块列表
        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel*2, num_res),
            EBlock(base_channel*4, num_res),
        ])

        # 初始化特征提取模块列表
        self.feat_extract = nn.ModuleList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])

        # 初始化解码器模块列表
        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        # 初始化中间卷积层模块列表
        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        # 初始化输出卷积层模块列表
        self.ConvsOut = nn.ModuleList(
            [
                BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
                BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
            ]
        )

        # 初始化自适应特征融合(Adaptive Feature Fusion)模块列表
        self.AFFs = nn.ModuleList([
            AFF(base_channel * 7, base_channel*1),
            AFF(base_channel * 7, base_channel*2)
        ])

        # 初始化特征注意力模块(Feature Attention Module)
        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)

    def forward(self, x):
        """
        前向传播函数。

        参数:
        - x: 输入图像的张量。

        返回值:
        - outputs: 一个包含多个中间输出的列表。
        """
        # 缩小输入图像尺寸
        x_2 = F.interpolate(x, scale_factor=0.5)
        x_4 = F.interpolate(x_2, scale_factor=0.5)
        # 通过SCM模块处理缩小后的图像
        z2 = self.SCM2(x_2)
        z4 = self.SCM1(x_4)

        outputs = list()

        # 通过特征提取(卷积层)和编码器进行处理,并保存中间结果
        x_ = self.feat_extract[0](x)
        res1 = self.Encoder[0](x_)

        z = self.feat_extract[1](res1)
        z = self.FAM2(z, z2)
        res2 = self.Encoder[1](z)

        z = self.feat_extract[2](res2)
        z = self.FAM1(z, z4)
        z = self.Encoder[2](z)

        # 进行特征的上采样和融合
        z12 = F.interpolate(res1, scale_factor=0.5)
        z21 = F.interpolate(res2, scale_factor=2)
        z42 = F.interpolate(z, scale_factor=2)
        z41 = F.interpolate(z42, scale_factor=2)

        res2 = self.AFFs[1](z12, res2, z42)
        res1 = self.AFFs[0](res1, z21, z41)

        # 通过解码器和输出卷积层得到最终输出
        z = self.Decoder[0](z)
        z_ = self.ConvsOut[0](z)
        z = self.feat_extract[3](z)
        outputs.append(z_+x_4)

        z = torch.cat([z, res2], dim=1)
        z = self.Convs[0](z)
        z = self.Decoder[1](z)
        z_ = self.ConvsOut[1](z)
        z = self.feat_extract[4](z)
        outputs.append(z_+x_2)

        z = torch.cat([z, res1], dim=1)
        z = self.Convs[1](z)
        z = self.Decoder[2](z)
        z = self.feat_extract[5](z)
        outputs.append(z+x)

        return outputs



class MIMOUNetPlus(nn.Module):
    """
    MIMOUNetPlus网络结构,用于图像处理任务。

    参数:
    - num_res (int): 用于编码器和解码器中的残差块的数量。

    返回:
    - 无
    """
    def __init__(self, num_res = 20):
        super(MIMOUNetPlus, self).__init__()

        # 初始化基本通道数和编码器模块列表
        base_channel = 32
        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel*2, num_res),
            EBlock(base_channel*4, num_res),
        ])

        # 初始化特征提取模块列表
        self.feat_extract = nn.ModuleList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])

        # 初始化解码器模块列表
        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        # 初始化中间卷积模块列表
        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        # 初始化输出卷积模块列表
        self.ConvsOut = nn.ModuleList(
            [
                BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
                BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
            ]
        )

        # 初始化注意力融合模块列表
        self.AFFs = nn.ModuleList([
            AFF(base_channel * 7, base_channel*1),
            AFF(base_channel * 7, base_channel*2)
        ])

        # 初始化特征注意力模块
        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)

        # 初始化dropout层
        self.drop1 = nn.Dropout2d(0.1)
        self.drop2 = nn.Dropout2d(0.1)

    def forward(self, x):
        """
        前向传播函数。

        参数:
        - x (Tensor): 输入图像的张量。

        返回:
        - outputs (list of Tensor): 输出图像的列表,每个元素为不同尺度下的处理结果。
        """
        # 缩小输入图像尺度
        x_2 = F.interpolate(x, scale_factor=0.5)
        x_4 = F.interpolate(x_2, scale_factor=0.5)

        # 通过SCM模块处理不同尺度的输入
        z2 = self.SCM2(x_2)
        z4 = self.SCM1(x_4)

        outputs = list()

        # 通过特征提取和编码器模块处理输入图像
        x_ = self.feat_extract[0](x)
        res1 = self.Encoder[0](x_)

        z = self.feat_extract[1](res1)
        z = self.FAM2(z, z2)
        res2 = self.Encoder[1](z)

        z = self.feat_extract[2](res2)
        z = self.FAM1(z, z4)
        z = self.Encoder[2](z)

        # 进行特征融合和尺度跳接
        z12 = F.interpolate(res1, scale_factor=0.5)
        z21 = F.interpolate(res2, scale_factor=2)
        z42 = F.interpolate(z, scale_factor=2)
        z41 = F.interpolate(z42, scale_factor=2)

        res2 = self.AFFs[1](z12, res2, z42)
        res1 = self.AFFs[0](res1, z21, z41)

        res2 = self.drop2(res2)
        res1 = self.drop1(res1)

        # 通过解码器和输出卷积模块得到最终输出
        z = self.Decoder[0](z)
        z_ = self.ConvsOut[0](z)
        z = self.feat_extract[3](z)
        outputs.append(z_+x_4)

        z = torch.cat([z, res2], dim=1)
        z = self.Convs[0](z)
        z = self.Decoder[1](z)
        z_ = self.ConvsOut[1](z)
        z = self.feat_extract[4](z)
        outputs.append(z_+x_2)

        z = torch.cat([z, res1], dim=1)
        z = self.Convs[1](z)
        z = self.Decoder[2](z)
        z = self.feat_extract[5](z)
        outputs.append(z+x)

        return outputs


def build_net(model_name):
    """
    根据提供的模型名称构建相应的网络模型。

    参数:
    - model_name: 字符串,指定要构建的模型名称。

    返回值:
    - 返回一个初始化好的指定模型实例。

    异常:
    - ModelError: 如果模型名称不正确,则抛出此自定义异常。
    """

    # 定义一个自定义的模型错误异常类
    class ModelError(Exception):
        def __init__(self, msg):
            # 初始化异常信息
            self.msg = msg

        def __str__(self):
            # 返回异常信息的字符串表示
            return self.msg

    # 根据模型名称选择对应的模型类进行实例化
    if model_name == "MIMO-UNetPlus":
        return MIMOUNetPlus()
    elif model_name == "MIMO-UNet":
        return MIMOUNet()

    # 如果模型名称不符合预期,抛出ModelError异常
    raise ModelError('Wrong Model!\nYou should choose MIMO-UNetPlus or MIMO-UNet.')


  • 10
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
MIMO-UNET是一种用于图像去模糊的算法,它结合了MIMO(Multiple-Input Multiple-Output)和UNET两种技术。下面是对MIMO-UNET去模糊算法的介绍: MIMO-UNET算法是基于深度学习的图像去模糊方法,它通过使用多个输入和多个输出来提高去模糊的效果。传统的UNET算法只使用单个输入和单个输出,而MIMO-UNET则引入了多个输入和多个输出,以更好地捕捉图像中的细节和纹理信息。 MIMO-UNET算法的核心思想是将图像的模糊处理问题转化为一个端到端的深度学习任务。它使用编码器-解码器结构,其中编码器负责提取图像的特征,解码器则负责将特征映射回原始图像空间。通过多个输入和多个输出,MIMO-UNET可以同时处理多个模糊程度的图像,并生成对应的清晰图像。 MIMO-UNET算法的训练过程包括两个阶段:训练编码器-解码器网络和训练多输入多输出网络。在第一个阶段,使用已知的清晰图像和对应的模糊图像对编码器-解码器网络进行训练,以学习图像的特征表示和重建能力。在第二个阶段,使用多个模糊程度的图像对多输入多输出网络进行训练,以学习不同程度模糊图像的去模糊映射。 MIMO-UNET算法在图像去模糊任务中具有较好的性能,它可以有效地恢复图像的细节和纹理信息。同时,MIMO-UNET还可以处理多个模糊程度的图像,适用于不同场景下的图像去模糊需求。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值