摘要
尽管多年来卷积网络一直是视觉任务的主要架构,但最近的实验表明,基于Transformer的模型,尤其是vision Transformer (ViT),在某些设置下可能会超过它们的性能。然而,由于Transformer 中自注意层的二次方运行时间(因为加入以单像素作为一个token,如果在每像素级别上应用 Transformer 中的自注意力层,它的计算成本将与每张图像的像素数成二次方扩展),ViT需要使用patch嵌入,将图像中的小区域组合成单个输入特征,以便应用于更大的图像尺寸。这就提出了一个问题:ViT的性能是由于固有的更强大的Transformer架构,还是某种程度上是因为使用patch作为输入表示实现的呢?
在本文中,我们为后者提供了一些证据:具体来说,我们提出了ConvMixer,这是一个非常简单的模型,在思路上类似于ViT和更基本的MLP-Mixer,因为这些模型直接对作为输入的patch操作,然后采用分离空间和通道维度的混合,并在整个网络中保持相同的大小和分辨率。然而,相比之下,ConvMixer只使用标准的卷积来实现混合步骤,空间方向使用单通道的组卷积。尽管它很简单,我们证明了ConvMixer在类似参数计数和数据集大小的情况下优于ViT、MLP-Mixer及其一些变体,此外还优于ResNet等经典视觉模型。我们的代码可以通过网址:GitHub - tmp-iclr/convmixer获得。
1 A Simple Model: ConvMixer
![](https://i-blog.csdnimg.cn/blog_migrate/54b0f3bfb066c8551c62be3a5cb0299c.png)
我们的模型,命名为ConvMixer,由一个patch embedding层和一个简单的全卷积块的重复堆叠组成。我们保持patch embeddings的空间结构,如图2所示。patchsize为以及embedding dimension为
的Patch embedding可以通过卷积实现,其中
为input channels,
为output channels, kernel size为
, 以及stride也为
的卷积来实现:
……(1)
例如一个 224 × 224 × 3 的自然图像作为输入,切分的 patch_size = 7,嵌入维度为 1536。如果在 MLP-Mixer 或者 ViT 中,其表达的含义是将 224 × 224 × 3 拆分为 32 × 32 个 patch,每个 patch 大小为 7 × 7 × 3 。然后将其进行展平成为一个 147维度的向量。经过一个线性层变为一个 1536 维度的向量。这其实就是等价于使用 1536 个 7 × 7 × 3的卷积核对自然图像进行卷积,且 stride = 7,padding = 0。这就将每个 patch 变为 1536 个特征值。
在 ConvMixer Layer 中,作者采用分离空间和通道维度的混合,即先通过一个 Depthwise Conv(即组数等于通道数的分组卷积),再通过一个 Pointwise Conv(即 1 × 1 卷积)。每个卷积之后是一个激活函数和激活后的 BatchNorm:
……(2)
……(3)
Motivation。正如Tolstikhinet al.(2021)所述,我们的架构基于混合的理念。特别地,我们对混合空间位置选择了深度卷积,对混合通道位置选择了点向卷积。之前工作的一个关键想法是,MLP和自我关注可以混合较远的空间位置,即它们可以具有任意大的感受野。值得注意的是:如果卷积核大小和 patch 个数一致,并且每个组之间共享卷积核权重,则 ConvMixer 就变为了 MLP-Mixer。
2 代码
import torch.nn as nn
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
def ConvMixer(dim, depth, kernel_size=9, patch_size=7, n_classes=1000):
return nn.Sequential(
nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
nn.GELU(),
nn.BatchNorm2d(dim),
*[nn.Sequential(
Residual(nn.Sequential(
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
nn.GELU(),
nn.BatchNorm2d(dim)
)),
nn.Conv2d(dim, dim, kernel_size=1),
nn.GELU(),
nn.BatchNorm2d(dim)
) for i in range(depth)],
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(),
nn.Linear(dim, n_classes)
)
3 补充
分组卷积:
pytorch之torch.nn.Conv2d()函数详解_夏普通-CSDN博客_torch.nn.conv2d