FB等提出全新卷积操作OctConv,速度接近理论极限

引言

论文地址
这篇论文是周一时带我的大佬(现在瑞士读博士,据说还在nips上面发过文章?,瑟瑟发抖)发给我一个一篇链接文章,博客是计划周五就要写出来的,但是由于要将maxnet的代码迁移到pytorch的resnet上面花费了一些时间。至今没见过这位大佬,我这位本科大白只是每周一阅读他发的论文和相关demo代码,改写或者迁移到现在的工业图像分类上。有想一起学习的可以加qq:1678354579进行讨论。
下面的内容由于时间有限,主要以代码实现为主。才疏学浅,如果那些错误还请大佬多多指正!

摘要

在自然图像中,信息总是在不同频率中表达的,其中高频信号一般包含丰富的细节而低频信号一般包含整体的结构。类似地,卷积层的输出特征图同样可以被看作是混合了不同频域的信息。在这项工作中,我们提出了如何根据频域去分解信息混合的特征图,并设计了一个新颖的八度卷积(Octave Convolution,OctConv)操作来保存和处理那些在较低空间分辨率下变化“较慢”(Slower)的特征图,从而减少存储和计算开销。与现有多尺度(multi-scale)方法不同的是,八度卷积被制定为一种单个通用的即插即用卷积单元,可以直接替换普通(vanilla)卷积而不需要对现有网络有任何调整。它同时也是对一些表明有着更好拓扑(topologies)或者减少通道冗余的方法的补充,并且与这些方法正交(orthogonal)。通过简单地用八度卷积替换普通卷积,我们在实验中发现我们在减少存储和计算开销的同时,还能持续提高图像和视频识别任务的准确率。一个使用八度卷积的ResNet-152网络能够在ImageNet上达到82.9%的Top-1分类准确率,而其浮点计算量仅仅只有22.2G(Giga)。

  • 总结下来就是:自然界的图像中高频的信息表示细腻而丰富的细节,低频表示整体的轮廓和布局。八度卷积最大的优点就是节省存储空间的运算力,而且有怎么如此强的功能只需要改动网络中卷积部分即可实现即插即用的功能!我的代码能力一般,大概花了一天左右的时间改写了octconv版的resnet,后期经过改动能够适应三种卷积的增强版
  • 加一句,关于低频和高频个人觉得可能搞美术的人更能理解。比如像画人物一样,大致的轮廓是差不多的,不经常改变为低频。具体的细节,一颦一动每个人都不一样为高频。本人为工科宅男一枚,献丑了?

原理浅谈

关于详细的原理,大家可以参考论文和一片中文博客。我这里更深的理解也是来源这篇博客,推荐大家去看看。
这里我主要从个人代码理解和实现的角度来聊一聊原理,说白了就是数学公式看的有点蒙逼。代码和公式相结合能够理解更深入。
传统的图像卷积是每一个卷积核为[kernel_size,kernel_size,in_channels],通过一系列相乘相加操作后得出特征图的一个像素点。如果是BP网络这一步就已经结束了,但是卷积网络会利用stride进行移动相同的卷积核得出下一个像素点。就这样按照步长在图像的宽高进行移动,得出一个通道的特征图,那如果我想要out_channels个通道的特征图。我只需要out_channels个卷积和就可以了,所以卷积的参数维度就是[kernel_size,kernel_size,in_channels,out_channels]。后期人们在消除特征图的冗余,人们又提出了grop_conv和depth_wise的卷积,对应的网络就是现在的resenxt和mobilenet。关于冗余的理解之前看过一本书上讲解是过多的输出通道,卷积核很大概率存在相似性,那么输出的特征图就会存在线性相关(简单说就是特征图的一个向量可以由另一个向量线性表示)。这部分如果大家有感到不太懂的,自动google关键字。或者加我私聊,欢迎骚扰!

好像有点扯远了,,,,现在开始进入重点啦!!八度卷积是在分辨率的维度提出低频的信息在传统的卷积中也存在冗余,通过将特征图分离成低频信息(低分辨率),高频信息(高分辨率)的达到节省存储和算力。大概估算一下,如果每一个特征图的一半为低频信息,那么他的分辨率降低为原始特征图的1/2,存储会卷积运算会减少1/4。
下采样刚才我们降低冗余是通过降低低频信息的分辨率,那么现在的问题是如何进行分辨率的降低呢?卷积网络中有两种下采样的方式,一种是池化(pool),一种是步长为2的卷积。论文的实验是说池化的方式会更有效
消融实验
将八度卷积嵌入到resnet中发现stride=2的卷积下采样并没有降低可训练的参数,而pool的下采样方式则数十倍的降低了参数量。具体的数值当时没有保存,应该会降低的更过。pool我们好理解,因为pool本来并没有可训练卷积,而stride=2的卷积下采样本质是将原始的卷积核分解成四份(中间卷积)或者两份(开始和结尾卷积),所以他的可训练参数是不会减少的。
八度卷积路线图
第一层卷积:输入图像默认全部为高频信息,故alpha_int=0,alpha_out= α \alpha α
在这里插入图片描述
中间层卷积,特征图包含低频和高频信息,一般设置为alpha_int=alpha_out= α \alpha α
在这里插入图片描述
最后一层卷积,回复正常特征图,故alpha_int= α \alpha α,alpha_out=0
在这里插入图片描述
这里的参数设置 α \alpha α一般为0.5,0.2。具体的参数设置会根据图像的特征丰富程度调整。
简单总结:特征图由第一层进入分为两路(低频信息和高频信息),中间层一直是两路信息,并且两路信息之间有交互,最终汇聚为一路信息输出。

具体实现代码

版本一 pool池化

# -*- coding: utf-8 -*-
# @Time    : 2019/4/22 13:29
# @Author  : ljf
import torch
import torch.nn.functional as F
from torch import nn


class OctConv2d_v1(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 alpha_in=0.5,
                 alpha_out=0.5
                 ):
        """adapt first octconv , octconv and last octconv

        """
        assert alpha_in >= 0 and alpha_in <= 1, "the value of alpha_in should be in range of [0,1],but get {}".format(
            alpha_in)
        assert alpha_out >= 0 and alpha_out <= 1, "the value of alpha_in should be in range of [0,1],but get {}".format(
            alpha_out)
        super(OctConv2d_v1, self).__init__(in_channels,
                                        out_channels,
                                        dilation,
                                        groups,
                                        bias,)
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.kernel_size = (1,1)
        self.stride = (1,1)
        self.avgPool = nn.AvgPool2d(kernel_size, stride, padding)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.inChannelSplitIndex = int(
            self.alpha_in * self.in_channels)
        self.outChannelSplitIndex = int(
            self.alpha_out * self.out_channels)
        # split bias
        if bias:
            self.hh_bias = self.bias[self.outChannelSplitIndex:]
            self.hl_bias = self.bias[:self.outChannelSplitIndex]
            self.ll_bias = self.bias[ :self.outChannelSplitIndex]
            self.lh_bias = self.bias[ self.outChannelSplitIndex:]
        else:
            self.hh_bias = None
            self.hl_bias = None
            self.ll_bias = None
            self.ll_bias = None

        # conv and upsample
        self.upsample = F.interpolate

    def forward(self, x):
        if not isinstance(x, tuple):
            # first octconv
            input_h = x if self.alpha_in == 0 else None
            input_l = x if self.alpha_in == 1 else None
        else:
            input_l = x[0]
            input_h = x[1]

        output = [0, 0]
        # H->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
            output_hh = F.conv2d(self.avgPool(input_h),
                                 self.weight[
                                 self.outChannelSplitIndex:,
                                 self.inChannelSplitIndex:,
                                 :, :],
                                 self.bias[self.outChannelSplitIndex:],
                                 self.kernel_size
                                 )

            output[1] += output_hh

        # H->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
            output_hl = F.conv2d(self.avgpool(self.avgPool(input_h)),
                                 self.weight[
                :self.outChannelSplitIndex,
                self.inChannelSplitIndex:,
                                     :, :],
                                 self.bias[:self.outChannelSplitIndex],
                                 self.kernel_size
                                 )

            output[0] += output_hl

        # L->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
            output_ll = F.conv2d((self.avgPool(input_l)),
                                 self.weight[
                                 :self.outChannelSplitIndex,
                                 :self.inChannelSplitIndex,
                                 :, :],
                                 self.bias[:self.outChannelSplitIndex],
                                 self.kernel_size
                                 )

            output[0] += output_ll

        # L->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
            output_lh = F.conv2d(self.avgPool(input_l),
                                 self.weight[
                                 self.outChannelSplitIndex:,
                                 :self.inChannelSplitIndex,
                                 :, :],
                                 self.bias[self.outChannelSplitIndex:],
                                 self.kernel_size
                                 )
            output_lh = self.upsample(output_lh, scale_factor=2)

            output[1] += output_lh

        if isinstance(output[0], int):
            out = output[1]
        else:
            out = tuple(output)
        return out
if __name__ == "__main__":
    input = torch.randn(1, 3, 32, 32)
    octconv1 = OctConv2d(
        in_channels=3,
        out_channels=6,
        kernel_size=3,
        padding=1,
        stride=2,
        alpha_in=0,
        alpha_out=0.3)
    octconv2 = OctConv2d(
        in_channels=6,
        out_channels=16,
        kernel_size=2,
        padding=0,
        stride=2,
        alpha_in=0.3,
        alpha_out=0.5)
    lastconv = OctConv2d(
        in_channels=16,
        out_channels=32,
        kernel_size=2,
        padding=0,
        stride=2,
        alpha_in=0.5,
        alpha_out=0)
    # bn1 = OctBN(3,3)
    # ac1 = OctAc(name="relu")
    out = octconv1(input)
    print(len(out))
    print(out[0].size())
    print(out[1].size())
    out = octconv2(out)
    print(len(out))
    print(out[0].size())
    print(out[1].size())

    out = lastconv(out)
    print(len(out))
    print(out[0].size())
    print(out[1])

版本二 stride=2的卷积

# -*- coding: utf-8 -*-
# @Time    : 2019/4/22 10:35
# @Author  : ljf
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class OctConv2d_v2(nn.Conv2d):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            alpha_in=0.5,
            alpha_out=0.5,):
        assert alpha_in >= 0 and alpha_in <= 1
        assert alpha_out >= 0 and alpha_out <= 1
        super(OctConv2d_v2, self).__init__(in_channels, out_channels,
                                           kernel_size, stride, padding,
                                           dilation, groups, bias)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.inChannelSplitIndex = math.floor(
            self.alpha_in * self.in_channels)
        self.outChannelSplitIndex = math.floor(
            self.alpha_out * self.out_channels)
        if bias:
            self.hh_bias = self.bias[self.outChannelSplitIndex:]
            self.hl_bias = self.bias[:self.outChannelSplitIndex]
            self.ll_bias = self.bias[ :self.outChannelSplitIndex]
            self.lh_bias = self.bias[ self.outChannelSplitIndex:]
        else:
            self.hh_bias = None
            self.hl_bias = None
            self.ll_bias = None
            self.lh_bias = None
    def forward(self, input):
        if not isinstance(input, tuple):
            assert self.alpha_in == 0 or self.alpha_in == 1
            inputLow = input if self.alpha_in == 1 else None
            inputHigh = input if self.alpha_in == 0 else None
        else:
            inputLow = input[0]
            inputHigh = input[1]

        output = [0, 0]
        # H->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != self.in_channels:
            outputH2H = F.conv2d(
                inputHigh,
                self.weight[
                    self.outChannelSplitIndex:,
                    self.inChannelSplitIndex:,
                    :,
                    :],
                self.hh_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[1] += outputH2H

        # H->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != self.in_channels:
            outputH2L = F.conv2d(
                self.avgpool(inputHigh),
                self.weight[
                    :self.outChannelSplitIndex,
                    self.inChannelSplitIndex:,
                    :,
                    :],
                self.hl_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[0] += outputH2L

        # L->L
        if self.outChannelSplitIndex != 0 and self.inChannelSplitIndex != 0:
            outputL2L = F.conv2d(
                inputLow,
                self.weight[
                    :self.outChannelSplitIndex,
                    :self.inChannelSplitIndex,
                    :,
                    :],
                self.ll_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[0] += outputL2L

        # L->H
        if self.outChannelSplitIndex != self.out_channels and self.inChannelSplitIndex != 0:
            outputL2H = F.conv2d(
                F.interpolate(inputLow, scale_factor=2),
                self.weight[
                    self.outChannelSplitIndex:,
                    :self.inChannelSplitIndex,
                    :,
                    :],
                self.lh_bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups)
            output[1] += outputL2H
        if isinstance(output[0],int):
            out = output[1]
        else:
            out = tuple(output)
        return out


if __name__ == "__main__":
    input = torch.randn(1, 3, 32, 32)
    octconv1 = OctConv2d(in_channels=3,
                         out_channels=6,
                         kernel_size=3,
                         stride=2,
                         padding=1,
                         dilation=1,
                         groups=1,
                         bias=True,
                         alpha_in=0.,
                         alpha_out=0.25)
    octconv2 = OctConv2d(in_channels=6,
                         out_channels=16,
                         kernel_size=3,
                         stride=1,
                         padding=1,
                         dilation=1,
                         groups=1,
                         bias=True,
                         alpha_in=0.25,
                         alpha_out=0.5)
    out = octconv1(input)
    print(len(out))
    print(out[0].shape)
    print(out[1].size())

    out = octconv2(out)
    print(len(out))
    print(out[0].size())
    print(out[1].size())

github地址

功力有限,还请各位多多包涵,多多指证。
参考文章:https://mp.weixin.qq.com/s?__biz=MzUyMjE2MTE0Mw==&mid=2247487810&idx=1&sn=1428510ec154a24a9e779d82f693930d&chksm=f9d14fdacea6c6cc42a630e57726c1789a54dc8e31616bd747fb2c35f41dbbd86f2c2a0b8998&mpshare=1&scene=23&srcid=#rd

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值