FcaNet代码理解——ICCV2021

原文链接: https://arxiv.org/abs/2012.11879
补充链接: https://arxiv.org/pdf/2012.11879v3.pdf
代码链接: https: //github.com/cfzd/FcaNet
在这里插入图片描述
FcaNet主要更新SE模块中squeeze部分,将SE模块中的全局平均池化,替换为DCT指导的全局池化。
全局平均池化: 计算每层通道上像素点的算术平均值。
DCT指导的全局池化: 计算每层通道上像素点的加权平均值。每个像素点的权重由2D DCT的基函数计算。下面会讲什么是2D DCT的基函数。

核心代码:

def get_freq_indices(method):
    assert method in ['top1','top2','top4','top8','top16','top32',
                      'bot1','bot2','bot4','bot8','bot16','bot32',
                      'low1','low2','low4','low8','low16','low32']
    num_freq = int(method[3:])
    if 'top' in method:
        all_top_indices_x = [0,0,6,0,0,1,1,4,5,1,3,0,0,0,3,2,4,6,3,5,5,2,6,5,5,3,3,4,2,2,6,1]
        all_top_indices_y = [0,1,0,5,2,0,2,0,0,6,0,4,6,3,5,2,6,3,3,3,5,1,1,2,4,2,1,1,3,0,5,3]
        mapper_x = all_top_indices_x[:num_freq]
        mapper_y = all_top_indices_y[:num_freq]
    elif 'low' in method:
        all_low_indices_x = [0,0,1,1,0,2,2,1,2,0,3,4,0,1,3,0,1,2,3,4,5,0,1,2,3,4,5,6,1,2,3,4]
        all_low_indices_y = [0,1,0,1,2,0,1,2,2,3,0,0,4,3,1,5,4,3,2,1,0,6,5,4,3,2,1,0,6,5,4,3]
        mapper_x = all_low_indices_x[:num_freq]
        mapper_y = all_low_indices_y[:num_freq]
    elif 'bot' in method:
        all_bot_indices_x = [6,1,3,3,2,4,1,2,4,4,5,1,4,6,2,5,6,1,6,2,2,4,3,3,5,5,6,2,5,5,3,6]
        all_bot_indices_y = [6,4,4,6,6,3,1,4,4,5,6,5,2,2,5,1,4,3,5,0,3,1,1,2,4,2,1,1,5,3,3,3]
        mapper_x = all_bot_indices_x[:num_freq]
        mapper_y = all_bot_indices_y[:num_freq]
    else:
        raise NotImplementedError
    return mapper_x, mapper_y

class MultiSpectralAttentionLayer(torch.nn.Module):
    def __init__(self, channel, dct_h, dct_w, reduction = 16, freq_sel_method = 'top16'):
        super(MultiSpectralAttentionLayer, self).__init__()
        self.reduction = reduction
        self.dct_h = dct_h
        self.dct_w = dct_w

        mapper_x, mapper_y = get_freq_indices(freq_sel_method)
        self.num_split = len(mapper_x)
        mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x] 
        mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]
        # make the frequencies in different sizes are identical to a 7x7 frequency space
        # eg, (2,2) in 14x14 is identical to (1,1) in 7x7

        self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        n,c,h,w = x.shape
        x_pooled = x
        if h != self.dct_h or w != self.dct_w:
            x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w))
            # If you have concerns about one-line-change, don't worry.   :)
            # In the ImageNet models, this line will never be triggered. 
            # This is for compatibility in instance segmentation and object detection.
        y = self.dct_layer(x_pooled)

        y = self.fc(y).view(n, c, 1, 1)
        return x * y.expand_as(x)


class MultiSpectralDCTLayer(nn.Module):
    """
    Generate dct filters
    """
    def __init__(self, height, width, mapper_x, mapper_y, channel):
        super(MultiSpectralDCTLayer, self).__init__()
        
        assert len(mapper_x) == len(mapper_y)
        assert channel % len(mapper_x) == 0

        self.num_freq = len(mapper_x)

        # fixed DCT init
        self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
        
        # fixed random init
        # self.register_buffer('weight', torch.rand(channel, height, width))

        # learnable DCT init
        # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))
        
        # learnable random init
        # self.register_parameter('weight', torch.rand(channel, height, width))

        # num_freq, h, w

    def forward(self, x):
        assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape))
        # n, c, h, w = x.shape

        x = x * self.weight

        result = torch.sum(x, dim=[2,3])
        return result

    def build_filter(self, pos, freq, POS):
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS) 
        if freq == 0:
            return result
        else:
            return result * math.sqrt(2)
    
    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
        dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)

        c_part = channel // len(mapper_x)

        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
            for t_x in range(tile_size_x):
                for t_y in range(tile_size_y):
                    dct_filter[i * c_part: (i+1)*c_part, t_x, t_y] = self.build_filter(t_x, u_x, tile_size_x) * self.build_filter(t_y, v_y, tile_size_y)
                        
        return dct_filter

1 离散余弦变换

二维离散余弦变换:
在这里插入图片描述

其中, f 2 d ∈ R H × W f^{2d} \in R^{H \times W} f2dRH×W是2D DCT频谱图, x 2 d ∈ R H × W x^{2d} \in R^{H \times W} x2dRH×W是输入图像,H和W分别为 x 2 d x^{2d} x2d图像的宽和高。 ( i , j ) (i,j) (i,j)是空间域坐标, ( h , w ) (h,w) (h,w)是频域坐标。

这个公式里的变量就是空间域坐标 ( i , j ) (i,j) (i,j)和频域坐标 ( h , w ) (h,w) (h,w),选定一个频域坐标 ( h , w ) (h,w) (h,w),该公式可以理解为:
频域 ( h , w ) (h,w) (h,w)的值 = = =输入图像( x 2 d x^{2d} x2d)乘以 ( h , w ) (h,w) (h,w)处的权重图,再相加求和。

2 权重图的选择

权重图有多少种组合?哪些权重图有效?怎么筛选?为了解答答这些问题,作者做了大量工作。

权重图由上述二维离散余弦变换的余弦部分即为2D DCT的基函数获得,记为 B B B,公式如下:
在这里插入图片描述

2.1 权重图有多少种组合?

权重图最好能被所有尺寸的特征图整除,而ImageNet上最小特征尺寸是 7 × 7 7\times7 7×7,所以权重图的尺寸被设置为 7 × 7 7\times7 7×7,则H=W=6,h和w分别可以取0~6(7个值),则 ( h , w ) (h,w) (h,w) 7 × 7 = 49 7\times7=49 7×7=49种组合,即有49种权重图。
(a)为49种特征图,(b)为筛选出的评分高的特征图
(a)为49种特征图,(b)为筛选出的评分高的16种特征图。

2.2 哪些权重图有效?怎么筛选?

这里的权重图就是频率分量(frequency comonents)。将49种频率分量放到网络里逐一测试,挑选出Top-K个频率分量。最终作者确定了Top-16个频率分量构建其Multi-spectral channel attention模块。

下图为49次测试在ImageNet上获得的的Top-1准度,低频分量会获得更好的结果。
在这里插入图片描述
下表为频率分量个数对结果的影响,选取前16个频率分量的结果最好。
在这里插入图片描述

3 回看代码

def get_freq_indices(method):

  • 用来选取频率分量坐标,最终method设置为top16
  • 返回值mapper_xmapper_y是长度为16的列表,放置Top-16的 ( h , w ) (h,w) (h,w)

class MultiSpectralDCTLayer(nn.Module):

  • 用来获得频率向量,即为下图。
  • def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):获得和输入特征图相同尺寸的权重图,按通道均分为16份,每份放入一种权重图(一张权重图通道维度上复制C/16次为一份)。
  • def build_filter(self, pos, freq, POS):计算一个余弦值。
    在这里插入图片描述

class MultiSpectralAttentionLayer(torch.nn.Module):

  • 为完整的Multi-spectral channel attention,可插入现成的网络中,如ResNet等

3.1 细节1:频率坐标缩放

没有将特征图切片为 7 × 7 7\times7 7×7大小,而是将频率坐标(temp_x, temp_y)根据输入特征尺寸(dct_h, dct_w)缩放。

mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x] 
mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]
# make the frequencies in different sizes are identical to a 7x7 frequency space
# eg, (2,2) in 14x14 is identical to (1,1) in 7x7

4 其他

看这篇论文的时候发现自己对频谱图的认识非常薄弱,查找了一些相关资料来辅助自己的理解。

4.1 傅里叶变换频谱图和DCT频谱图

2D 傅里叶变换频谱图特点:
频谱图以该图像的中心为圆心:

  • 圆的半径对应频率的高低。低频半径小,高频半径大,中心为直流分量
  • 圆的相位对应原图频率分量的相位。(通俗理解,不严谨)
  • 频谱图的灰度值对应该频率的能量高低
    小节参考:https://zhuanlan.zhihu.com/p/454090354

2D DCT频谱图特点:

  • 低频信息集中在矩阵的左上角。
  • 高频信息则向右下角集中。
  • 直流分量在[0,0]处。
  • [0,1]处的余弦部分在一个方向上是一个半周期的余弦函数,在另一个方向上是一个常数。[1,0]处的余弦部分与[0,1]类似,只不过方向旋转了90度。
    小节参考:《食用数学信号处理-从原理到应用》,人民邮电出版社

这本书里对DCT基函数进行了比较清晰的可视化,如下图:
在这里插入图片描述

  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值