文章目录
paper:SCConv: Spatial and Channel Reconstruction Convolution for Feature Redundancy
1、Spatial and Channel reconstruction Convolution
由于现有的 CNN 模型中存在的特征冗余问题,传统的卷积操作会产生大量冗余的特征,这会导致模型计算量和存储空间的需求增加。为了解决这个问题,论文提出了一种 空间通道重建卷积(Spatial and Channel reconstruction Convolution),通过联合减少空间和通道冗余,在降低模型复杂度和计算成本的同时,提升模型的性能。SCConv 由两个部件组成:空间重构单元(SRU)和通道重构单元(CRU)。
SCConv的构成并不复杂,结构清晰,对于一个输入 X 来说,依次经过 SRU 与 CRU 处理即可。
SCConv 结构图:
2、Spatial Reconstruction Unit
SRU 利用 Group Normalization 层中的缩放因子来评估特征图的空间信息含量。并通过分离和重建操作,抑制特征图在空间维度上的冗余,增强特征表示能力。具体步骤如下:
- 标准化: 首先,使用 Group Normalization 层对输入特征图进行标准化处理,消除特征图在不同通道之间的尺度差异。并利用缩放因子计算每个特征图的重要性权重。
- 门控操作: 将权重值映射到 (0, 1) 范围,并通过门控操作将特征图分为 informative 和 non-informative 两部分。
- 重建操作: 将 informative 特征图与 non-informative 特征图进行交叉重建,并通过拼接操作得到空间精炼的特征图。
SRU 结构图:
3、Channel Reconstruction Unit
CRU 的核心思想是利用分割-转换-融合策略,通过轻量级的卷积操作提取丰富的特征信息,并通过特征复用方案减少计算成本和存储空间。其具体实现步骤如下:
- 分割操作: 将输入特征图的通道分为两部分,分别进行操作。分割比例 α 可以根据实际情况进行调整。
- 上层转换: 对一部分特征图进行 GWC 和 PWC 操作,提取丰富的高层次特征信息。GWC 和 PWC 的优势在于:GWC: 减少参数量和计算量,但会切断通道组之间的信息流。PWC: 补偿信息损失,并促进通道之间的信息流动。
- 下层转换: 对另一部分特征图进行 PWC 操作,提取补充的细节信息,并通过特征复用方案获取更多特征图。
- 融合操作: 利用全局平均池化收集全局空间信息,并使用通道软注意力操作生成特征重要性向量,最后根据特征重要性向量融合特征图。
CRU 结构图:
4、代码实现
import torch
import torch.nn.functional as F
import torch.nn as nn
class GroupBatchnorm2d(nn.Module):
def __init__(self, c_num: int,
group_num: int = 16,
eps: float = 1e-10
):
super(GroupBatchnorm2d, self).__init__()
assert c_num >= group_num
self.group_num = group_num
self.weight = nn.Parameter(torch.randn(c_num, 1, 1))
self.bias = nn.Parameter(torch.zeros(c_num, 1, 1))
self.eps = eps
def forward(self, x):
N, C, H, W = x.size()
x = x.view(N, self.group_num, -1)
mean = x.mean(dim=2, keepdim=True)
std = x.std(dim=2, keepdim=True)
x = (x - mean) / (std + self.eps)
x = x.view(N, C, H, W)
return x * self.weight + self.bias
class SRU(nn.Module):
def __init__(self,
oup_channels: int,
group_num: int = 16,
gate_treshold: float = 0.5,
torch_gn: bool = True
):
super().__init__()
self.gn = nn.GroupNorm(num_channels=oup_channels, num_groups=group_num) if torch_gn else GroupBatchnorm2d(
c_num=oup_channels, group_num=group_num)
self.gate_treshold = gate_treshold
self.sigomid = nn.Sigmoid()
def forward(self, x):
gn_x = self.gn(x)
w_gamma = self.gn.weight / sum(self.gn.weight)
w_gamma = w_gamma.view(1, -1, 1, 1)
reweigts = self.sigomid(gn_x * w_gamma)
# Gate
w1 = torch.where(reweigts > self.gate_treshold, torch.ones_like(reweigts), reweigts) # 大于门限值的设为1,否则保留原值
w2 = torch.where(reweigts > self.gate_treshold, torch.zeros_like(reweigts), reweigts) # 大于门限值的设为0,否则保留原值
x_1 = w1 * x
x_2 = w2 * x
y = self.reconstruct(x_1, x_2)
return y
def reconstruct(self, x_1, x_2):
x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1)
x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1)
return torch.cat([x_11 + x_22, x_12 + x_21], dim=1)
class CRU(nn.Module):
'''
alpha: 0<alpha<1
'''
def __init__(self,
op_channel: int,
alpha: float = 1 / 2,
squeeze_radio: int = 2,
group_size: int = 2,
group_kernel_size: int = 3,
):
super().__init__()
self.up_channel = up_channel = int(alpha * op_channel)
self.low_channel = low_channel = op_channel - up_channel
self.squeeze1 = nn.Conv2d(up_channel, up_channel // squeeze_radio, kernel_size=1, bias=False)
self.squeeze2 = nn.Conv2d(low_channel, low_channel // squeeze_radio, kernel_size=1, bias=False)
# up
self.GWC = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=group_kernel_size, stride=1,
padding=group_kernel_size // 2, groups=group_size)
self.PWC1 = nn.Conv2d(up_channel // squeeze_radio, op_channel, kernel_size=1, bias=False)
# low
self.PWC2 = nn.Conv2d(low_channel // squeeze_radio, op_channel - low_channel // squeeze_radio, kernel_size=1,
bias=False)
self.advavg = nn.AdaptiveAvgPool2d(1)
def forward(self, x):
# Split
up, low = torch.split(x, [self.up_channel, self.low_channel], dim=1)
up, low = self.squeeze1(up), self.squeeze2(low)
# Transform
Y1 = self.GWC(up) + self.PWC1(up)
Y2 = torch.cat([self.PWC2(low), low], dim=1)
# Fuse
out = torch.cat([Y1, Y2], dim=1)
out = F.softmax(self.advavg(out), dim=1) * out
out1, out2 = torch.split(out, out.size(1) // 2, dim=1)
return out1 + out2
class ScConv(nn.Module):
def __init__(self,
op_channel: int,
group_num: int = 4,
gate_treshold: float = 0.5,
alpha: float = 1 / 2,
squeeze_radio: int = 2,
group_size: int = 2,
group_kernel_size: int = 3,
):
super().__init__()
self.SRU = SRU(op_channel,
group_num=group_num,
gate_treshold=gate_treshold)
self.CRU = CRU(op_channel,
alpha=alpha,
squeeze_radio=squeeze_radio,
group_size=group_size,
group_kernel_size=group_kernel_size)
def forward(self, x):
x = self.SRU(x)
x = self.CRU(x)
return x
if __name__ == '__main__':
x = torch.randn(4, 512, 7, 7)
model = ScConv(512)
print(model(x).shape)
本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。