图像超分领域的模块FMB代码分享

上文已经介绍了图像超分领域的模块FMB的作用,本文讲这个模块分享一下。

SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution( ECCV 2024 )
https://openaccess.thecvf.com/content/CVPR2024W/NTIRE/papers/Ren_The_Ninth_NTIRE_2024_Efficient_Super-Resolution_Challenge_Report_CVPRW_2024_paper.pdf

import torch
import torch.nn as nn
import torch.nn.functional as F

"""
DMlp 类
DMlp(动态MLP)类实现了一个简单的多层感知机,使用卷积层:

初始化 (__init__ 方法):

dim: 输入特征维度。
growth_rate: 增加隐藏层维度的因子。
hidden_dim: 计算为 dim * growth_rate。
conv_0: 一个顺序层,包含:
深度卷积(groups=dim)。
点卷积。
act: 激活函数(GELU)。
conv_1: 点卷积,将维度投射回 dim。
前向传播 (forward 方法):

应用 conv_0,然后激活,再应用 conv_1。
PCFN 类
PCFN(部分卷积的前馈网络)类:

初始化:

dim: 输入特征维度。
growth_rate: 增加隐藏层维度的因子。
p_rate: 部分卷积的隐藏维度比例。
hidden_dim: 计算为 dim * growth_rate。
p_dim: 计算为 hidden_dim * p_rate。
conv_0: 点卷积。
conv_1: 对隐藏维度的一部分应用卷积。
act: 激活函数(GELU)。
conv_2: 点卷积,将维度投射回 dim。
前向传播:

对训练和评估模式有不同的行为。
在训练期间,将 conv_0 的输出分成两部分,对第一部分应用 conv_1,然后连接,再应用 conv_2。
在评估期间,直接对特征图的一部分应用 conv_1。
SMFA 类
SMFA(自调制特征聚合)模块:

初始化:

dim: 输入特征维度。
多个卷积层用于线性变换。
lde: 一个 DMlp 类的实例。
dw_conv: 深度卷积。
gelu: 激活函数。
down_scale: 下采样的缩放因子。
alpha 和 belt: 可学习的调制参数。
前向传播:

使用 linear_0 将输入分成两部分。
对第二部分(x)的下采样版本应用深度卷积。
计算特征图的方差,并与下采样特征图结合。
将调制后的特征图与 DMlp 的输出结合。
FMB 类
FMB(特征调制块):

初始化:

dim: 输入特征维度。
ffn_scale: 前馈网络的缩放因子。
SMFA 和 PCFN 的实例。
前向传播:

对归一化的输入应用 SMFA,将结果加到输入上。
对归一化的输入应用 PCFN,将结果加到输入上。
"""
class DMlp(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        self.conv_0 = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
            nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
        )
        self.act = nn.GELU()
        self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)

    def forward(self, x):
        x = self.conv_0(x)
        x = self.act(x)
        x = self.conv_1(x)
        return x

# partial convolution-based feed-forward network
class PCFN(nn.Module):
    def __init__(self, dim, growth_rate=2.0, p_rate=0.25):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        p_dim = int(hidden_dim * p_rate)
        self.conv_0 = nn.Conv2d(dim, hidden_dim, 1, 1, 0)
        self.conv_1 = nn.Conv2d(p_dim, p_dim, 3, 1, 1)

        self.act = nn.GELU()
        self.conv_2 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)

        self.p_dim = p_dim
        self.hidden_dim = hidden_dim

    def forward(self, x):
        if self.training:
            x = self.act(self.conv_0(x))
            x1, x2 = torch.split(x, [self.p_dim, self.hidden_dim - self.p_dim], dim=1)
            x1 = self.act(self.conv_1(x1))
            x = self.conv_2(torch.cat([x1, x2], dim=1))
        else:
            x = self.act(self.conv_0(x))
            x[:, :self.p_dim, :, :] = self.act(self.conv_1(x[:, :self.p_dim, :, :]))
            x = self.conv_2(x)
        return x

#self-modulation feature aggregation (SMFA) module
class SMFA(nn.Module):
    def __init__(self, dim=36):
        super(SMFA, self).__init__()
        self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
        self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
        self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)

        self.lde = DMlp(dim, 2)

        self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

        self.gelu = nn.GELU()
        self.down_scale = 8

        self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
        self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))

    def forward(self, f):
        _, _, h, w = f.shape
        y, x = self.linear_0(f).chunk(2, dim=1)
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
        x_v = torch.var(x, dim=(-2, -1), keepdim=True)
        x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w),
                                mode='nearest')
        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)

#Feature modulation block(FMB)
class FMB(nn.Module):
    def __init__(self, dim, ffn_scale=2.0):
        super().__init__()

        self.smfa = SMFA(dim)
        self.pcfn = PCFN(dim, ffn_scale)

    def forward(self, x):
        x = self.smfa(F.normalize(x)) + x
        x = self.pcfn(F.normalize(x)) + x
        return x


if __name__ == '__main__':
    # 定义输入张量的形状 (batch_size, channels, height, width)
    input_shape = (1, 36, 64, 64)

    # 创建一个随机张量作为输入
    input_tensor = torch.randn(input_shape)

    # 实例化FMB类
    fmb = FMB(dim=36)

    # 将输入张量传入FMB实例
    output_tensor = fmb(input_tensor)

    # 打印输入和输出的形状
    print(f"输入形状: {input_tensor.shape}")
    print(f"输出形状: {output_tensor.shape}")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值