SpectralTransform(频谱变换):先对输入做可选的下采样和 1×1 卷积降维,再分别送入全局傅里叶单元(FU)和局部傅里叶单元(LFU),然后将两个分支的结果与 1×1 卷积后的特征相加,最后经过 1×1 卷积融合。其中局部傅里叶单元取前 1/4 个通道,先沿高度、再沿宽度各拆分为两半,并在通道维度上拼接(通道数变为 4 倍,空间尺寸减半),再次进行傅里叶变换,最后在空间上按 2×2 平铺以恢复原尺寸。
通过将空间维度分拆并重组,将空间分割后的特征映射到更多的通道上,可以使得网络能够更加专注于局部区域的特征提取,在频域处理之后再次进行空域的特征组合,从而实现频域和空域特征的有效融合。增加通道维度通常会增加网络的表达能力,因为它提供了更多的特征图用于学习不同的特征。
class SpectralTransform(nn.Module):
    """Spectral block that fuses a global and a local Fourier branch.

    Steps performed by ``forward``:

    1. Optional 2x2 average-pool downsampling (only when ``stride == 2``).
    2. 1x1 conv + BatchNorm + ReLU reducing the channels to
       ``out_channels // 2``.
    3. Global branch: ``self.fu`` applied to the whole feature map.
    4. Local branch (only when ``enable_lfu``): the first quarter of the
       channels is cut into a 2x2 grid of spatial tiles, the tiles are
       stacked along the channel axis, passed through ``self.lfu`` and
       tiled back 2x2 to the original spatial size.
    5. The reduced features and both branch outputs are summed and projected
       to ``out_channels`` by a final 1x1 conv.

    ``FourierUnit`` is defined outside this block; the residual sum in
    ``forward`` requires that it returns a tensor of the same shape as its
    input.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor. The internal
            width is ``out_channels // 2``.
        stride: ``2`` enables average-pool downsampling; any other value
            leaves the spatial size unchanged.
        groups: ``groups`` argument forwarded to both 1x1 convolutions and
            to the Fourier units.
        enable_lfu: Whether to build and use the local Fourier branch.
        **fu_kwargs: Extra keyword arguments forwarded to the global
            ``FourierUnit`` only (the local unit is built without them).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        groups=1,
        enable_lfu=True,
        **fu_kwargs,
    ):
        super().__init__()
        self.enable_lfu = enable_lfu

        # Only stride == 2 downsamples; every other value is a pass-through.
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()
        self.stride = stride

        hidden_channels = out_channels // 2
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels,
                hidden_channels,
                kernel_size=1,
                groups=groups,
                bias=False,
            ),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(hidden_channels, hidden_channels, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = FourierUnit(hidden_channels, hidden_channels, groups)
        self.conv2 = torch.nn.Conv2d(
            hidden_channels, out_channels, kernel_size=1, groups=groups, bias=False
        )

    def forward(self, x):
        """Apply the spectral transform to a 4-D tensor ``(b, c, h, w)``.

        Returns the output of the final 1x1 conv applied to the sum of the
        reduced features, the global branch and (optionally) the local branch.
        """
        x = self.downsample(x)
        x = self.conv1(x)  # channels: in_channels -> out_channels // 2
        output = self.fu(x)  # global Fourier branch

        # With the local branch disabled, 0 is a neutral element for the sum.
        xs = self._local_fourier(x) if self.enable_lfu else 0

        return self.conv2(x + output + xs)

    def _local_fourier(self, x):
        """Run the local Fourier branch on the first quarter of the channels.

        The height and width are halved independently, so non-square feature
        maps are supported.
        NOTE(review): h and w are expected to be even, and the channel count
        divisible by 4 so that the stacked tensor matches what ``self.lfu``
        was built for -- confirm against ``FourierUnit``.
        """
        split_no = 2
        _, c, h, w = x.shape
        split_h = h // split_no
        split_w = w // split_no

        # (b, c/4, h, w) -> (b, c/2, h/2, w): stack top/bottom halves on channels.
        xs = torch.cat(
            torch.split(x[:, : c // 4], split_h, dim=-2), dim=1
        ).contiguous()
        # (b, c/2, h/2, w) -> (b, c, h/2, w/2): stack left/right halves on channels.
        xs = torch.cat(torch.split(xs, split_w, dim=-1), dim=1).contiguous()

        xs = self.lfu(xs)
        # Tile the half-size result 2x2 to get back to (b, c, h, w).
        return xs.repeat(1, 1, split_no, split_no).contiguous()