傅里叶卷积Fourier Convolutions

傅里叶卷积Fourier Convolutions

对输入tensor进行实数FFT(rfftn),提取频谱的实部和虚部并沿通道维拼接,在频率域上进行1×1卷积(后接BN和ReLU),再将结果拆分回实部和虚部、合成复数tensor,最后通过逆FFT(irfftn)还原为空间域的tensor。
卷积之前可以对频域特征使用SE模块,对各个通道(包括实部/虚部通道)进行重标定。

class FourierUnit(nn.Module):
    """Spectral (frequency-domain) convolution unit, as used in LaMa's FFC.

    The input is transformed with a real FFT. The real and imaginary parts are
    stacked as channels, and a 1x1 convolution + BatchNorm + ReLU is applied in
    the frequency domain. The result is recombined into a complex spectrum and
    transformed back with an inverse real FFT.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        groups: ``groups`` of the 1x1 spectral convolution.
        spatial_scale_factor: if not None, the input is resized by this factor
            before the FFT and the output is resized back to the input's
            spatial size afterwards.
        spatial_scale_mode: ``F.interpolate`` mode used for that resizing
            (any mode, including 'nearest' / 'area', is accepted).
        spectral_pos_encoding: if True, two channels holding linspace(0, 1)
            coordinates over the frequency grid (vertical, horizontal) are
            concatenated before the convolution.
        use_se: if True, apply ``SELayer`` (defined elsewhere) to the spectral
            features before the convolution.
        se_kwargs: extra keyword arguments forwarded to ``SELayer``.
        ffc3d: if True, the FFT runs over the last three dims instead of two.
        fft_norm: ``norm`` argument passed to ``torch.fft.rfftn`` / ``irfftn``.
    """

    # Modes for which F.interpolate accepts the ``align_corners`` argument.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
        super(FourierUnit, self).__init__()
        self.groups = groups

        # Real and imaginary parts are stacked as channels, hence the "* 2";
        # the optional positional encoding contributes 2 more input channels.
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
                                          out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

        # squeeze and excitation block
        self.use_se = use_se
        if use_se:
            if se_kwargs is None:
                se_kwargs = {}
            self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def _resize(self, x, **size_kwargs):
        """Resize ``x`` with ``F.interpolate`` using ``self.spatial_scale_mode``.

        ``size_kwargs`` is either ``scale_factor=...`` or ``size=...``.
        """
        # F.interpolate raises ValueError when align_corners is given (even as
        # False) together with a non-interpolating mode such as 'nearest'.
        if self.spatial_scale_mode in self._ALIGN_CORNERS_MODES:
            size_kwargs['align_corners'] = False
        return F.interpolate(x, mode=self.spatial_scale_mode, **size_kwargs)

    def forward(self, x):
        # NOTE(review): the permute below fixes the stacked spectrum at rank 5,
        # so x must be 4D (b, c, h, w) even when ffc3d is set; in that case the
        # "3D" FFT also spans the channel dim. Confirm this is intended.
        batch = x.shape[0]  # (b, c, h, w)

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = self._resize(x, scale_factor=self.spatial_scale_factor)

        # FFT over the last two dims, or the last three when ffc3d is set.
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        # For real input the spectrum is conjugate-symmetric, so rfftn keeps
        # only the non-redundant half of the last dim: w // 2 + 1 bins.
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)  # (b, c, h, w//2+1), complex
        # Split the complex spectrum into real-valued channels.
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (b, c, h, w//2+1, 2)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (b, c, 2, h, w//2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])  # (b, 2c, h, w//2+1)

        if self.spectral_pos_encoding:
            height, width = ffted.shape[-2:]
            coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
            coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)  # (b, 2c+2, h, w//2+1)

        if self.use_se:
            ffted = self.se(ffted)

        # 1x1 convolution in the frequency domain.
        ffted = self.conv_layer(ffted)  # (b, 2*out_c, h, w//2+1)
        ffted = self.relu(self.bn(ffted))

        # Undo the channel stacking: the (channel, real/imag) layout matches the
        # view used on the way in.
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (b, out_c, h, w//2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])  # (b, out_c, h, w//2+1)

        # Pass the spatial size explicitly: an odd width cannot be recovered
        # from w//2+1 bins alone.
        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)  # (b, out_c, h, w)

        if self.spatial_scale_factor is not None:
            output = self._resize(output, size=orig_size)

        return output

参考:
https://github.com/advimman/lama

  • 7
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值