FFC: Fast Fourier Convolution
- Global receptive field: by applying the Fourier transform, FFC sees global image features from the earliest convolutional stages, which helps on tasks that depend on long-range context, such as image inpainting and super-resolution.
- Parameter efficiency: because FFC draws on information from the entire input image, it uses its parameters more effectively; complex reasoning and generation can happen in early layers instead of waiting for information to percolate through a deep stack.
- Scale equivariance: sharing the same convolution kernel across all frequencies biases the model toward scale-equivariant behavior, i.e. consistent performance across input scales, which is particularly useful for images with periodic structure.
- Generalization to high resolutions: FFC layers stay efficient and produce high-quality output on high-resolution images, even at resolutions never seen during training.
- Easy integration: FFC is fully differentiable and works as a drop-in replacement for an ordinary convolution layer, so it slots into existing deep-learning architectures without major changes.
- Capturing periodic structure: FFC is especially good at capturing and reconstructing repeating patterns common in man-made environments, such as brick walls, ladders, and windows.
- Less wasted computation: in conventional convolution stacks much of the computation is effectively spent waiting for information to propagate in from far away; an FFC layer uses global information directly, which improves computational efficiency.
The Fourier-transform machinery inside FFC is implemented by two modules: SpectralTransform and FourierUnit.
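Their full source is not reproduced in this note, so here is a minimal sketch of both, simplified from the reference implementation: the optional LFU (local Fourier unit) branch and the spectral position encoding are omitted, and the real/imaginary parts are concatenated as channel blocks instead of interleaved per channel. Treat it as an illustration of the mechanism, not a drop-in copy.

```python
import torch
import torch.nn as nn


class FourierUnit(nn.Module):
    # Convolution in the frequency domain: rfftn -> 1x1 conv + BN + ReLU
    # on the stacked real/imaginary channels -> irfftn.
    def __init__(self, in_channels, out_channels, groups=1):
        super().__init__()
        # Real and imaginary parts are stacked along channels, hence the * 2.
        self.conv_layer = nn.Conv2d(in_channels * 2, out_channels * 2,
                                    kernel_size=1, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels * 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        h, w = x.shape[-2:]
        # Real FFT over the spatial dims -> complex (b, c, h, w // 2 + 1).
        ffted = torch.fft.rfftn(x, dim=(-2, -1), norm='ortho')
        ffted = torch.cat((ffted.real, ffted.imag), dim=1)
        # Every spectral coefficient depends on all pixels, so even this
        # 1x1 convolution mixes information from the whole image.
        ffted = self.relu(self.bn(self.conv_layer(ffted)))
        real, imag = ffted.chunk(2, dim=1)
        return torch.fft.irfftn(torch.complex(real, imag), s=(h, w),
                                dim=(-2, -1), norm='ortho')


class SpectralTransform(nn.Module):
    # 1x1 bottleneck -> FourierUnit -> residual sum -> 1x1 expansion.
    # enable_lfu is accepted only for signature compatibility with the FFC
    # class below; the LFU branch itself is omitted in this sketch.
    def __init__(self, in_channels, out_channels, stride=1, groups=1,
                 enable_lfu=True):
        super().__init__()
        self.downsample = (nn.AvgPool2d(2, stride=2) if stride == 2
                           else nn.Identity())
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True))
        self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups)
        self.conv2 = nn.Conv2d(out_channels // 2, out_channels, kernel_size=1,
                               groups=groups, bias=False)

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        return self.conv2(x + self.fu(x))
```

This is where the global receptive field in the bullet list above comes from: the pointwise convolution acts on Fourier coefficients, each of which already aggregates the entire image.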
```python
import torch
import torch.nn as nn


class FFC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True,
                 padding_type='reflect', gated=False, **spectral_kwargs):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        # Split the channels into a global part (cg) and a local part (cl).
        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        # Four paths: local->local, local->global and global->local are
        # ordinary convolutions; global->global goes through the spectral
        # (FFT) branch. A path with zero input or output channels collapses
        # to nn.Identity.
        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(in_cl, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias,
                              padding_mode=padding_type)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(in_cl, out_cg, kernel_size,
                              stride, padding, dilation, groups, bias,
                              padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(in_cg, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias,
                              padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg, out_cg, stride,
            1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)

        # Optional gating: a 1x1 conv over all input channels predicts
        # sigmoid gates for the two cross-branch paths (g2l and l2g).
        self.gated = gated
        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
        self.gate = module(in_channels, 2, 1)

    def forward(self, x):
        # The input is either a (local, global) tuple or a plain tensor;
        # in the latter case the global branch starts out as the scalar 0.
        x_l, x_g = x if type(x) is tuple else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input_parts = [x_l]
            if torch.is_tensor(x_g):
                total_input_parts.append(x_g)
            total_input = torch.cat(total_input_parts, dim=1)

            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        # Each output branch sums its two incoming paths.
        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        return out_xl, out_xg
```
Only two configurations actually occur in the code; simplified, they reduce to the following (both are exercised in the sanity check after this list).

- When ratio_gin = 0 and ratio_gout = 0 (the downsampling layers, where out_channels = 2 * in_channels), only the local path survives:

```python
# x_l is the whole input; x_g = 0
in_cg  = 0
in_cl  = in_channels
out_cg = 0
out_cl = out_channels

self.convl2l = nn.Conv2d(in_cl, out_cl, ...)  # the only real path
self.convl2g = nn.Identity()
self.convg2l = nn.Identity()
self.convg2g = nn.Identity()

# convg2l(x_g) * g2l_gate == Identity(0) * 1 == 0, so effectively
# out_xl = convl2l(x_l): a plain (strided) convolution
out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
```
- When ratio_gin = 0.75 and ratio_gout = 0.75 with in_channels = out_channels = 512 (the residual blocks), the input and output shapes stay the same:

```python
# x_l: (b, 128, h/8, w/8)
# x_g: (b, 384, h/8, w/8)
in_cg  = int(in_channels * ratio_gin)    # 384
in_cl  = in_channels - in_cg             # 128
out_cg = int(out_channels * ratio_gout)  # 384
out_cl = out_channels - out_cg           # 128

self.convl2l = nn.Conv2d(in_cl, out_cl, ...)
self.convl2g = nn.Conv2d(in_cl, out_cg, ...)
self.convg2l = nn.Conv2d(in_cg, out_cl, ...)
self.convg2g = SpectralTransform(in_cg, out_cg, ...)  # not part of nn

out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
```
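A hedged sanity check for both cases (batch size and spatial sizes below are arbitrary; assumes the FFC class above plus a SpectralTransform implementation, e.g. the sketch earlier, are in scope):

```python
# Case 1: ratio_gin = ratio_gout = 0 degenerates to a single strided conv;
# the global output stays the scalar 0.
ffc_down = FFC(64, 128, kernel_size=3, ratio_gin=0, ratio_gout=0,
               stride=2, padding=1)
out_l, out_g = ffc_down(torch.randn(2, 64, 64, 64))
print(out_l.shape, out_g)          # torch.Size([2, 128, 32, 32]) 0

# Case 2: ratio_gin = ratio_gout = 0.75 preserves both branch shapes.
ffc_res = FFC(512, 512, kernel_size=3, ratio_gin=0.75, ratio_gout=0.75,
              padding=1)
x_l = torch.randn(2, 128, 32, 32)  # in_cl = 512 - 384 = 128
x_g = torch.randn(2, 384, 32, 32)  # in_cg = int(512 * 0.75) = 384
out_l, out_g = ffc_res((x_l, x_g))
print(out_l.shape, out_g.shape)    # [2, 128, 32, 32] and [2, 384, 32, 32]
```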
FFC_BN_ACT wraps FFC with a normalization and an activation layer per branch; a branch that produces no channels gets nn.Identity instead:
```python
class FFC_BN_ACT(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 padding_type='reflect',
                 enable_lfu=True, **kwargs):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, enable_lfu,
                       padding_type=padding_type, **kwargs)
        # ratio_gout == 1 means no local output, ratio_gout == 0 means no
        # global output; the corresponding norm/activation become Identity.
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        self.bn_l = lnorm(out_channels - global_channels)
        self.bn_g = gnorm(global_channels)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        # nn.Identity silently ignores the inplace argument.
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g
```
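One detail the two cases above gloss over: a plain-tensor input only acquires a global branch once some layer has ratio_gin = 0 but ratio_gout > 0. Below is a hedged sketch of such a transition layer; the 256/512 channel counts and spatial sizes are arbitrary choices for illustration, not values taken from any particular configuration.

```python
# ratio_gin = 0: the input is a single tensor (x_g starts out as 0).
# ratio_gout = 0.75: the output splits into 128 local + 384 global channels.
layer = FFC_BN_ACT(256, 512, kernel_size=3, ratio_gin=0, ratio_gout=0.75,
                   stride=2, padding=1, activation_layer=nn.ReLU)
x_l, x_g = layer(torch.randn(2, 256, 64, 64))
print(x_l.shape, x_g.shape)
# torch.Size([2, 128, 32, 32]) torch.Size([2, 384, 32, 32])
```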
FFCResnetBlock is the FFC residual block: two FFC_BN_ACT layers form the residual branch.
```python
class FFCResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU,
                 dilation=1, spatial_transform_kwargs=None, inline=False,
                 **conv_kwargs):
        super().__init__()
        self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation,
                                dilation=dilation,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **conv_kwargs)
        self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation,
                                dilation=dilation,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **conv_kwargs)
        # LearnableSpatialTransformWrapper comes from the original repo and
        # is only needed when spatial_transform_kwargs is given.
        if spatial_transform_kwargs is not None:
            self.conv1 = LearnableSpatialTransformWrapper(
                self.conv1, **spatial_transform_kwargs)
            self.conv2 = LearnableSpatialTransformWrapper(
                self.conv2, **spatial_transform_kwargs)
        self.inline = inline

    def forward(self, x):
        if self.inline:
            # Inline mode: both branches travel as one tensor; split off the
            # last global_in_num channels as the global part.
            x_l, x_g = (x[:, :-self.conv1.ffc.global_in_num],
                        x[:, -self.conv1.ffc.global_in_num:])
        else:
            x_l, x_g = x if type(x) is tuple else (x, 0)

        id_l, id_g = x_l, x_g  # identities for the residual connection

        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))

        # Residual addition on each branch separately.
        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        if self.inline:
            out = torch.cat(out, dim=1)
        return out
```
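Finally, a hedged usage sketch of the residual block: ratio_gin and ratio_gout reach FFC through **conv_kwargs, and the residual addition only works because every layer in the block preserves both branch shapes (the sizes mirror the 0.75 case discussed earlier):

```python
block = FFCResnetBlock(dim=512, padding_type='reflect',
                       norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU,
                       ratio_gin=0.75, ratio_gout=0.75)
x_l = torch.randn(2, 128, 32, 32)
x_g = torch.randn(2, 384, 32, 32)
y_l, y_g = block((x_l, x_g))
print(y_l.shape, y_g.shape)  # identical to the input shapes
```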