深度学习之图像分类(三十)-- Hire-MLP网络详解

深度学习之图像分类(三十)Hire-MLP网络详解

一晃都学习了三十个网络了,时间过得真快。本次学习华为提出的 Hire-MLP,依然是通过旋转特征图,将不同位置的特征对齐到同一个通道上从而实现 MLP-based Model 中的局部感受野依赖。

请添加图片描述

1. 前言

本此学习华为诺亚&北大&悉尼大学联合提出的 HireMLP。关于 MLP-Mixer 的两大主要问题我们已经在前面的学习中阐述了很多遍:对于 Token 进行全局感受野卷积,容易过拟合且对图像尺寸敏感,无法作为 Backbone 用于下游任务。Hire-MLP 也提出:如何在 MLP-based 模型中结合局部感受野和全局感受野,同时对图像输入分辨率不敏感是值得探索的地方。为了对输入图像尺寸不敏感,要不使用固定尺寸小卷积核(例如 3 × 3 3 \times 3 3×3 的DW Conv);或者就是移动特征图,使得不同 token 对齐到同一个通道上,然后使用 1 × 1 1 \times 1 1×1 卷积来实现局部感受野,即不同 token 之间的信息融合。Hire-MLP 也是一样的思路,通过移动特征图(本工作称之为区域重排 Region Rearrangement),从而实现局部信息的整合。本工作的原始论文为 Hire-MLP: Vision MLP via Hierarchical Rearrangement。2021.8.30 挂上 arXiv,代码并未开源。最终在 ImageNet 上达到了 83.4% 的 Top-1 精度,这与 SOTA 的 Swin Transformer 等相当。虽然 Hire-MLP 结构上对下游任务友好,但是本文并没有将其应用于下游任务,从而不清楚其真实性能。

请添加图片描述

从结果可见:

  • Hire-MLP-S 取得了 81.8% 的精度,而计算量仅为 4.2G Flops,优于其他 MLP 方案。相比 AS-MLP、CycleMLP,所提 Hire-MLP 性能更佳。
  • Hire-MLP-B 与 Hire-MLP-L 分别取得了 83.1% 与 83.4% 的精度,而计算量分别为 8.1G 与 13.5G。
  • 相比 DeiT、Swin 以及 PVT,所提方案具有更快的推理速度;
  • 相比 RegNetY,所提方案具有更高的精度,同时具有相似的模型大小和复杂度。
  • 相比AS-MLP,Hire-MLP好像并没有什么优势,性能相当,速度反而AS-MLP更快 。(AS-MLP 源代码的 Shift 操作是用 cupy 库自己编程实现的,可能会有影响)

2. Hire-MLP

这次讲解我先讲局部 Hire-MLP Block 结构,再描述网络的整体结构。

2.1 Hire-MLP Block

单个 Hire-MLP Block 依然是分为 Token-mixing MLP 和 Channel mixing MLP,其中作者主要的贡献点在于替换 MLP-mixer 的 Token-mixing MLP 为 Hire-Module。所以整个 Hire-MLP Block 可以描述为:
Y =  Hire-Module  ( LN ⁡ ( X ) ) + X Z =  Channel-MLP  ( LN ⁡ ( Y ) ) + Y \begin{aligned} &Y=\text { Hire-Module }(\operatorname{LN}(X))+X \\ &Z=\text { Channel-MLP }(\operatorname{LN}(Y))+Y \end{aligned} Y= Hire-Module (LN(X))+XZ= Channel-MLP (LN(Y))+Y
Channel MLP 就是最普通的两层全连接,中间使用 GELU 激活函数,第一层全连接结点个数一般为输入结点个数的 2,3 或者 4 倍。或者可以直接称之为通道方向的 1 × 1 1 \times 1 1×1 卷积。

Hire-Module 的内部结构如下图所示,其中包含三条支路,分别是对于 H 方向的重排,W 方向的重排,以及通道方向的映射:

请添加图片描述

重排分为两类:Cross-Region 以及 Inner-Region

2.1.1 Inner-Region

Height-direction 的 Inner-Region 其实就是对特征图进行 H 维度的分组,然后将其切分开之后堆叠到通道维度。这里 H H H 为原始特征图的高度, h h h 为每一小组内特征图的高度。即对于一个 H × W × C H \times W \times C H×W×C 的特征图可以分成 g = H / h g = H / h g=H/h 组,每组的特征图大小为 h × W × C h \times W \times C h×W×C,然后对特征图进行重排得到 g × W × ( h C ) g \times W \times (hC) g×W×(hC),此后做 h C − > h C hC->hC hC>hC 的映射。映射结束后,再还原到 H × W × C H \times W \times C H×W×C 的特征图即可。可视化流程如下所示:

请添加图片描述

具体的映射其实是依赖两个全连接和一个 GELU 激活函数构成的,论文中作者将第一个全连接层的阶段设定为 C / 2 C / 2 C/2 用于降维和减少计算量。其实也可以只依赖两个 1 × 1 1 \times 1 1×1 卷积加以实现。

Inner-Region 重排其实可以只需要依赖 einops 库的 Rearrange 即可方便完成:

import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

class InnerRegionW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b c h (w group) -> b (c w) h group', w = self.w)		# 重排
        )

    def forward(self, x):
        return self.region(x)

class InnerRegionRestoreW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b (c w) h group -> b c h (w group)', w = self.w)		# 恢复
        )

    def forward(self, x):
        return self.region(x)

model1 = InnerRegionW(w = 2)
model2 = InnerRegionRestoreW(w = 2)
images = torch.randn(1, 1, 4, 4)

print(images)
print("==============================")

with torch.no_grad():
    output1 = model1(images)
    output2 = model2(output1)
print("=========== output1 ==============")
print(output1)
print("=========== output2 ==============")
print(output2)
2.1.2 Cross-Region

单个方向的 Inner-Region 产生的是特定的一维线性感受野,即将 h × w h \times w h×w 拆分成 1 × w 1 \times w 1×w h × 1 h \times 1 h×1 的形式,由于输入是 224 × 224 224 \times 224 224×224 的方图,所以一般配置中 h = w h = w h=w。并且 Inner-Region 的空间位置相对固定,即和 Swin 分窗一样,窗口固定了,则窗口处理始终对应的都是固定的区域。为了使得不同窗内的 Token 有交互,需要移动窗,或者说需要旋转特征图。这就引发了 Cross-Region。请联系 Swin 一起思考,则容易理解多了

请添加图片描述

Height-direction 的 Inner-Region 其实就是对特征图进行 H 维度的转动,转动的步长设定为 s s s ,通常 s s s 取 1 或者 2。Cross-Region 操作被加在 Inner-Region 前后。实际上并不是每个 Inner-Region 前后都添加 Cross-Region,因为每一个都添加一样的 Cross-Region 其实等于没有添加。所以原文中作者是隔一个添加一个,隔一个添加一个。To get a global receptive field, the cross-region rearrangement operations are inserted before the inner-region rearrangement operation every two blocks. 但是作者将这个称为全局感受野我其实是不太认同的,毕竟 s s s 太小了,只不过边界地区挨到一起来罢了,这其实也是局部感受野而已

Cross-Region 重排其实可以只需要依赖 torch.roll 即可方便完成:

import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

class CrossRegion(nn.Module):
    def __init__(self, step = 1, dim = 1):
        super().__init__()
        self.step = step
        self.dim = dim

    def forward(self, x):
        return torch.roll(x, self.step, self.dim)

model1 = CrossRegion(step = 2, dim = 3)
model2 = CrossRegion(step = -2, dim = 3)

model3 = CrossRegion(step = 1, dim = 2)
model4 = CrossRegion(step = -1, dim = 2)

images = torch.randn(1, 1, 5, 5)

print(images)
print("==============================")

with torch.no_grad():
    output1 = model1(images)
    output2 = model2(images)
    output3 = model3(images)
    output4 = model4(images)
print("=========== output1 ==============")
print(output1)
print("=========== output2 ==============")
print(output2)
print("=========== output3 ==============")
print(output3)
print("=========== output4 ==============")
print(output4)
2.1.3 特征融合

在 Hire-Module 中有三个并行的支路,第三条之路只需要通过一个 C − > C C -> C C>C 的单层全连接层映射。最后三条支路的特征图直接加和即可获得最终的融合特征图。其实用一下 Split-Attention 或者如 sMLPNet 说的通道拼接后经过 1 × 1 1 \times 1 1×1 卷积降维可能还会涨点,不过计算量会有所增加。所以最终不难分析得到 Hire-Module 是一个十字形的感受野。整个 Hire-MLP Block 的推理逻辑如下所示:

def forward(self, x):
    x_h = self.inner_regionH(self.cross_regionH(x))						# 重排
    x_w = self.inner_regionW(self.cross_regionW(x))						# 重排
        
    x_h = self.proj_h(x_h)												# 映射
    x_w = self.proj_w(x_w)												# 映射
    x_c = self.proj_c(x)												# 映射

    x_h = self.cross_region_restoreH(self.inner_region_restoreH(x_h))	# 恢复
    x_w = self.cross_region_restoreW(self.inner_region_restoreW(x_w))	# 恢复

    out = x_c + x_h + x_w												# 特征融合
    return out
2.1.4 HireMLP 和 ViP,AS-MLP 的区别?

HireMLP 和 ViP 有区别吗?在我看来没有本质区别,ViP 也是分组重排操作。ViP 更复杂地使用了 Split-Attention 和残差;HireMLP 则是将中间的映射改为两层全连接。HireMLP 作者说 Inspired by the shortcut in ResNet and ViP, an extra branch without spatial communication is alse added ...,但是看起来似乎不仅仅是像他们一样添加一个 extra branch without spatial communication 这么简单,而是整个模块其实思想都是惊人的一致。至于说指标报告方面 HireMLP 高了那么 0.2%,会不会是因为两层全连接导致的?尚不可而知。不过相比 ViP 的重排,能肉眼可见的改进就是这个地方了。

请添加图片描述

HireMLP 和 AS-MLP 有区别吗?在我看来没有本质区别,AS-MLP 也是十字形感受野,不过是位移的实现方式上有所不同。整体而言都是类似的。相比AS-MLP,Hire-MLP好像并没有什么优势,性能相当,速度反而AS-MLP更快。所以说:在特征图移动上玩花似乎已经走到头了,玩不出什么花来了。

请添加图片描述

2.2 整体网络结构

Hire-MLP 的 Patch Embedding 很有特色,使用卷积核大小为 7 × 7 7 \times 7 7×7 ,步长为 4 的卷积。相比而言 Swin 使用卷积核大小为 4 × 4 4 \times 4 4×4,步长为 4 的卷积。在近期的我自己的小实验中也发现:Patch Embedding 时具有重叠会更好,这样可以避免边界效应并在小数据集上提升性能。Hire-MLP 中间采用多阶段金字塔模型,总共分为 4 个阶段,每个阶段交替重复使用 Hire-MLP Block。下采样使用卷积核大小为 3 × 3 3 \times 3 3×3,步长为 2 的卷积,这样做也有重叠。最后经过全局池化后连接一个全连接分类器即可,其网络结构图如下所示:

请添加图片描述

作者一共提出来了四种配置:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RuWvCRJJ-1634726515046)(hiremlp-3.png)]

大家注意这四种配置,其中 Hire-MLP Block 中 h , w , s h,w,s h,w,s 的取值一样的。 h , w h,w h,w 指 Inner-Region 中分小组内 H 或者 W 维度的大小, s s s 指 Cross-Region 处旋转的步长。注意到 H H H 不一定整除 h h h,让我们以 Base 为例,作者在训练 ImageNet 的时候以 224 × 224 224 \times 224 224×224 的图片作为输入,经过第一个 Patch Embedding 之后,假设作者 kernel size = 7, stride = 4,假设使用了 padding = 3,则此后特征图大小为 56 × 56 × 64 56 \times 56 \times 64 56×56×64。此后在 Stage 1 中 h = w = 4 h = w = 4 h=w=4,即在 Inner-Region 中可以分为 56 / 4 = 14 56 / 4 = 14 56/4=14 组,特征图可以重排为 14 × 56 × 256 14 \times 56 \times 256 14×56×256,然后做 256 − > 32 − > 256 256 -> 32 -> 256 256>32>256 的映射 ( h C − > C / 2 − > h C hC -> C/2 -> hC hC>C/2>hC),看上去没什么问题。经过下采样的 Patch Embedding,kernel size = 3,stride = 2,假设使用了 padding = 1,则此后特征图大小为 28 × 28 × 128 28 \times 28 \times 128 28×28×128。这里 28 根本不是 3 的倍数,所以需要进行 padding 操作。作者通过对比发现 Circular padding 是最好的,但是 padding 类别其实对结果的影响并不大。

请添加图片描述

3. 消融实验

作者一共进行了多组消融实验:

  • Inner-Region 中 h , w h, w h,w 的影响:作者发现浅层使用大一点点的 h , w h,w h,w,深层使用小一点的 h , w h, w h,w 效果更好,最终使用 h = w = [ 4 , 3 , 3 , 2 ] h = w = [4,3,3,2] h=w=[4,3,3,2]

请添加图片描述

  • Cross-Region 中 s s s 的影响:作者发现浅层使用大一点点的 s s s,深层使用小一点的 s s s 效果更好,最终使用 s = [ 2 , 2 , 1 , 1 ] s = [2,2,1,1] s=[2,2,1,1]所以其实 Cross-Region 并没有实现大感受野信息交换

请添加图片描述

  • 不同 padding 策略的影响:最终发现padding 类别其实对结果的影响并不大,Circular padding 效果略好。

请添加图片描述

  • Hire Module 中不同模块的作用效果:可见去除 Inner-Region 或者 Cross-Region 性能都会下降,其中 Inner-Region 有关局部感受野,去除了只剩下 Cross-Region 其实等价于朴素的只有通道方向的映射,所以性能影响更大。但是只有通道方向的映射,真的作者能做到 79.81% 吗?我持保留意见

请添加图片描述

  • Cross-Region 的 shift 方式:作者比较了直接旋转的 Shifted manner 与 ShuffleNet 那样的分组的方式,最终发现 Shifted manner 会更好,这也符合直觉,因为图像中随机膈几列选一列的做法其实局部性并不保证了。可见局部性还是比较重要的。但是真的性能只差一丝丝吗

请添加图片描述

4. 总结与反思

Hire-MLP 其实依然是使用特征图移动,使得不同空间位置的 token 对齐到统一通道,然后使用通道方向的 1 × 1 1 \times 1 1×1 卷积实现的局部依赖性引入和对图像分辨率不敏感,这种思想在过去的诸多工作例如 AS-MLP 以及 ViP,S2MLPv2 等中均可看到,所以并不是什么新鲜的贡献。此外作者说 Hire-MLP 对下游任务友好,但是并没有进行实验让人有点点失望。本工作暂未开源,所报告性能我持怀疑态度,因为作者消融实验中去除 Inner-Region 只保留 Cross-Region 其实可看作仅有 Channel 方向 1 × 1 1 \times 1 1×1 卷积的网络,也能达到 79.81%?比 MLP-Mixer 的 76+% 还高

5. 代码

我自己实现的非官方 pytorch 代码见 此处,欢迎与大家自己复现的进行交流。

import torch
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from .utils import pair


class PreNormResidual(nn.Module):
    def __init__(self, dim, fn, norm = nn.LayerNorm):
        super().__init__()
        self.fn = fn
        self.norm = norm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

class PatchEmbedding(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, stride, padding, norm_layer=False):
        super().__init__()
        self.reduction = nn.Sequential(
                            nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding = padding),

                            nn.Identity() if (not norm_layer) else nn.Sequential(
                                Rearrange('b c h w -> b h w c'),
                                nn.LayerNorm(dim_out),
                                Rearrange('b h w c -> b c h w'),
                            )
                        )

    def forward(self, x):
        return self.reduction(x)

class FeedForward(nn.Module):
    def __init__(self, dim_in, hidden_dim, dim_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, kernel_size = 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim_out, kernel_size = 1),
        )
    def forward(self, x):
        return self.net(x)

class CrossRegion(nn.Module):
    def __init__(self, step = 1, dim = 1):
        super().__init__()
        self.step = step
        self.dim = dim

    def forward(self, x):
        return torch.roll(x, self.step, self.dim)

class InnerRegionW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b c h (w group) -> b (c w) h group', w = self.w)
        )

    def forward(self, x):
        return self.region(x)

class InnerRegionH(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.region = nn.Sequential(
            Rearrange('b c (h group) w -> b (c h) group w', h = self.h)
        )

    def forward(self, x):
        return self.region(x)

class InnerRegionRestoreW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b (c w) h group -> b c h (w group)', w = self.w)
        )

    def forward(self, x):
        return self.region(x)

class InnerRegionRestoreH(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.region = nn.Sequential(
            Rearrange('b (c h) group w -> b c (h group) w', h = self.h)
        )

    def forward(self, x):
        return self.region(x)

class HireMLPBlock(nn.Module):
    def __init__(self, h, w, d_model, cross_region_step = 1, cross_region_id = 0, cross_region_interval = 2, padding_type = 'circular'):
        super().__init__()

        assert (padding_type in ['constant', 'reflect', 'replicate', 'circular'])
        self.padding_type = padding_type
        self.w = w
        self.h = h

        # cross region every cross_region_interval HireMLPBlock
        self.cross_region = (cross_region_id % cross_region_interval == 0)

        if self.cross_region:
            self.cross_regionW = CrossRegion(step = cross_region_step, dim = 3)
            self.cross_regionH = CrossRegion(step = cross_region_step, dim = 2)
            self.cross_region_restoreW = CrossRegion(step = -cross_region_step, dim = 3)
            self.cross_region_restoreH = CrossRegion(step = -cross_region_step, dim = 3)
        else:
            self.cross_regionW = nn.Identity()
            self.cross_regionH = nn.Identity()
            self.cross_region_restoreW = nn.Identity()
            self.cross_region_restoreH = nn.Identity()

        self.inner_regionW = InnerRegionW(w)
        self.inner_regionH = InnerRegionH(h)
        self.inner_region_restoreW = InnerRegionRestoreW(w)
        self.inner_region_restoreH = InnerRegionRestoreH(h)


        self.proj_h = FeedForward(h * d_model, d_model // 2, h * d_model)
        self.proj_w = FeedForward(w * d_model, d_model // 2, w * d_model)
        self.proj_c = nn.Conv2d(d_model, d_model, kernel_size = 1)
    
    def forward(self, x):
        x = x.permute(0, 3, 1, 2)

        B, C, H, W = x.shape
        padding_num_w = W % self.w
        padding_num_h = H % self.h
        x = nn.functional.pad(x, (0, self.w - padding_num_w, 0, self.h - padding_num_h), self.padding_type)

        x_h = self.inner_regionH(self.cross_regionH(x))
        x_w = self.inner_regionW(self.cross_regionW(x))
        
        x_h = self.proj_h(x_h)
        x_w = self.proj_w(x_w)
        x_c = self.proj_c(x)

        x_h = self.cross_region_restoreH(self.inner_region_restoreH(x_h))
        x_w = self.cross_region_restoreW(self.inner_region_restoreW(x_w))

        out = x_c + x_h + x_w

        out = out[:,:,0:H,0:W]
        out = out.permute(0, 2, 3, 1)
        return out

class HireMLPStage(nn.Module):
    def __init__(self, h, w, d_model_in, d_model_out, depth, cross_region_step, cross_region_interval, expansion_factor = 2, dropout = 0., pooling = False, padding_type = 'circular'):
        super().__init__()

        self.pooling = pooling
        self.patch_merge = nn.Sequential(
            Rearrange('b h w c -> b c h w'),
            PatchEmbedding(d_model_in, d_model_out, kernel_size = 3, stride = 2, padding=1, norm_layer=False),
            Rearrange('b c h w -> b h w c'),
        )

        self.model = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(d_model_in, nn.Sequential(
                    HireMLPBlock(
                        h, w, d_model_in, cross_region_step = cross_region_step, cross_region_id = i_depth + 1, cross_region_interval = cross_region_interval, padding_type = padding_type
                    )
                ), norm = nn.LayerNorm),
                PreNormResidual(d_model_in, nn.Sequential(
                    nn.Linear(d_model_in, d_model_in * expansion_factor),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_model_in * expansion_factor, d_model_in),
                    nn.Dropout(dropout),
                ), norm = nn.LayerNorm),
            ) for i_depth in range(depth)]
        )

    def forward(self, x):
        x = self.model(x)
        if self.pooling:
            x = self.patch_merge(x)
        return x

class HireMLP(nn.Module):
    def __init__(
        self,
        patch_size=4,
        in_channels=3,
        num_classes=1000,
        d_model=[64, 128, 320, 512],
        h = [4,3,3,2],
        w = [4,3,3,2],
        cross_region_step = [2,2,1,1],
        cross_region_interval = 2,
        depth=[4,6,24,3],
        expansion_factor = 2,
        patcher_norm = False,
        padding_type = 'circular',
    ):
        patch_size = pair(patch_size)
        super().__init__()
        self.patcher = PatchEmbedding(dim_in = in_channels, dim_out = d_model[0], kernel_size = 7, stride = patch_size, padding = 3, norm_layer=patcher_norm)
        

        self.layers = nn.ModuleList()
        for i_layer in range(len(depth)):
            i_depth = depth[i_layer]
            i_stage = HireMLPStage(h[i_layer], w[i_layer], d_model[i_layer], d_model_out = d_model[i_layer + 1] if (i_layer + 1 < len(depth)) else d_model[-1],
                depth = i_depth, cross_region_step = cross_region_step[i_layer], cross_region_interval = cross_region_interval,
                expansion_factor = expansion_factor, pooling = ((i_layer + 1) < len(depth)), padding_type = padding_type)
            self.layers.append(i_stage)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_model[-1]),
            Reduce('b h w c -> b c', 'mean'),
            nn.Linear(d_model[-1], num_classes)
        )

    def forward(self, x):
        embedding = self.patcher(x)
        embedding = embedding.permute(0, 2, 3, 1)
        for layer in self.layers:
            embedding = layer(embedding)
        out = self.mlp_head(embedding)
        return out
  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

木卯_THU

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值