SpectralTransform

SpectralTransform,频谱变换,对输入进行傅里叶变换和局部傅里叶变换,然后将两次变换结果和输入相加,再经过卷积融合。其中局部傅里叶变换,取前1/4个通道,分别在宽高维度拆分为两部分,通过重组增加通道数,再次进行傅里叶变换。

通过将空间维度分拆并重组,将空间分割后的特征映射到更多的通道上,可以使得网络能够更加专注于局部区域的特征提取,在频域处理之后再次进行空域的特征组合,从而实现频域和空域特征的有效融合。增加通道维度通常会增加网络的表达能力,因为它提供了更多的特征图用于学习不同的特征。

class SpectralTransform(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        groups=1,
        enable_lfu=True,
        **fu_kwargs,
    ):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
            ),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
        )

    def forward(self, x):
        x = self.downsample(x)  # (b,c,h,w)
        x = self.conv1(x) # (b,c/2,h,w)
        output = self.fu(x)  # (b,c/2,h,w)  #全局傅里叶变换

        if self.enable_lfu: #局部傅里叶变换
            n, c, h, w = x.shape  # (b,c/2,h,w)
            split_no = 2
            split_s = h // split_no  #h==w
            xs = torch.cat(
                torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
            ).contiguous()  #(b,c//8,h/2,w) => (b,c//4,h/2,w)
            xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
            # (b,c//4,h/2,w/2) => (b,c/2,h/2,w/2)
            xs = self.lfu(xs) # (b,c/2,h/2,w/2)
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()   # (b,c/2,h,w)
        else:
            xs = 0

        output = self.conv2(x + output + xs)  # (b,c/2,h,w) =>(b,c,h,w)

        return output
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值