谷歌团队最新的文章《MLP-Mixer: An all-MLP Architecture for Vision》提出了一个无需卷积CNN,无需attention的多层感知机网络。文章试图证明Neither of them are necessary.
由于网络并不复杂,我尝试自己搭建了一下模型。
1.模型的规格参照paper中的Table 1,各形参名称都与该表一一对应。
2. 模型的整体框架如下图,其中搭建的时候分割成三部分:
两个全连接的mlp-block,
token-mixing和channel-mixing的Mixer Layer
图片分割和Global average pooling 以及分类头
代码块如下,还是比较清晰明了的
import torch
import torch.nn as nn
from torchsummary import summary
# 这个是两层FC加一个激活函数的mlp block
# 因为有两个mixing,进出的维度都不变,只是中间全连接层的神经元数量不同
class mlp_block(nn.Module):
    """Two-layer feed-forward MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The input and output feature dimensions are both ``in_channels``; the two
    mixing blocks keep their I/O shape unchanged and differ only in the hidden
    width ``mlp_dim`` of the middle layer.
    """

    def __init__(self, in_channels, mlp_dim, drop_ratio=0.):
        super().__init__()
        # Assemble the stages as a list first, then wrap in one Sequential so
        # the submodule layout (self.block.0 .. self.block.4) stays the same.
        stages = [
            nn.Linear(in_channels, mlp_dim),
            nn.GELU(),
            nn.Dropout(drop_ratio),
            nn.Linear(mlp_dim, in_channels),
            nn.Dropout(drop_ratio),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        # Pure feed-forward pass; shape is preserved.
        return self.block(x)
class mlp_layer(nn.Module):
    """One Mixer layer: token-mixing MLP + channel-mixing MLP, each preceded by
    its own LayerNorm and wrapped in a residual (skip) connection.

    Fix vs. the original: the paper applies an *independent* LayerNorm before
    each mixing block, but the original reused a single ``nn.LayerNorm`` module
    for both, making the two normalizations share one set of affine parameters.
    This version gives each mixing block its own LayerNorm.

    NOTE(review): despite the parameter names, the caller (``mlp_mixer``)
    passes D_S into ``mlp_dimension_dc`` and D_C into ``mlp_dimension_ds``
    positionally; the wiring below is kept consistent with that caller
    (token-mixing hidden width = D_S, channel-mixing hidden width = D_C).

    :param seq_length_s: number of patches/tokens S (row dimension)
    :param hidden_size_c: hidden channel width C (column dimension)
    :param mlp_dimension_dc: hidden width given to the token-mixing block
    :param mlp_dimension_ds: hidden width given to the channel-mixing block
    """

    def __init__(self, seq_length_s, hidden_size_c, mlp_dimension_dc, mlp_dimension_ds):
        super().__init__()
        # Two independent norms so each mixing block has its own parameters.
        self.ln_token = nn.LayerNorm(hidden_size_c)
        self.ln_channel = nn.LayerNorm(hidden_size_c)
        # token-mixing acts along the sequence (S) axis, channel-mixing along
        # the hidden (C) axis — hence the different in_channels.
        self.token_mixing = mlp_block(in_channels=seq_length_s, mlp_dim=mlp_dimension_dc)
        self.channel_mixing = mlp_block(in_channels=hidden_size_c, mlp_dim=mlp_dimension_ds)

    def forward(self, x):
        """x: [B, S, C] -> [B, S, C]."""
        y = self.ln_token(x)
        y = y.transpose(1, 2)      # [B, C, S]: expose tokens to the MLP
        y = self.token_mixing(y)
        y = y.transpose(1, 2)      # back to [B, S, C]
        x = x + y                  # skip-connection around token mixing
        z = self.channel_mixing(self.ln_channel(x))
        return x + z               # skip-connection around channel mixing
# 按照paper中的 Table 1 来配置参数
class mlp_mixer(nn.Module):
    """MLP-Mixer classifier ("MLP-Mixer: An all-MLP Architecture for Vision").

    Pipeline: non-overlapping conv patch embedding -> ``layer_num`` Mixer
    layers -> final LayerNorm -> global average pooling over tokens -> linear
    classification head.  Defaults follow the B/32 row of Table 1.

    :param num_classes: size of the classifier output
    :param img_size: expected (square) input resolution
    :param in_channels: number of input image channels
    :param layer_num: number of stacked Mixer layers
    :param patch_size: side length of each square patch
    :param hidden_size_c: hidden (channel) width C of the patch embedding
    :param seq_length_s: number of patches S; must equal (img_size // patch_size) ** 2
    :param mlp_dimension_dc: channel-mixing MLP hidden width D_C
    :param mlp_dimension_ds: token-mixing MLP hidden width D_S
    :raises ValueError: if the patch geometry is inconsistent
    """

    def __init__(self,
                 num_classes=1000,
                 img_size=224,
                 in_channels=3,
                 layer_num=12,
                 patch_size=32,
                 hidden_size_c=768,
                 seq_length_s=49,
                 mlp_dimension_dc=3072,
                 mlp_dimension_ds=384,
                 ):
        super().__init__()
        # Robustness fix: validate the patch geometry at construction time
        # instead of failing with an opaque shape error inside the first
        # token-mixing block during forward().
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size}).")
        expected_s = (img_size // patch_size) ** 2
        if seq_length_s != expected_s:
            raise ValueError(
                f"seq_length_s ({seq_length_s}) must equal "
                f"(img_size // patch_size) ** 2 = {expected_s}.")
        self.num_classes = num_classes
        self.img_size = img_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.layer_num = layer_num
        self.hidden_size_c = hidden_size_c
        self.seq_length_s = seq_length_s
        self.mlp_dimension_dc = mlp_dimension_dc
        self.mlp_dimension_ds = mlp_dimension_ds
        # Final pre-pooling LayerNorm.
        self.ln = nn.LayerNorm(self.hidden_size_c)
        # Patch embedding: a strided conv both splits the image into
        # patch_size x patch_size patches and projects each one to C channels.
        self.proj = nn.Conv2d(self.in_channels, self.hidden_size_c,
                              kernel_size=self.patch_size, stride=self.patch_size)
        # Stack of Mixer layers.  NOTE(review): D_S is passed into mlp_layer's
        # `mlp_dimension_dc` slot and D_C into `mlp_dimension_ds` — the names
        # there are swapped, but the resulting wiring matches the paper
        # (token-mixing hidden width D_S, channel-mixing hidden width D_C).
        self.mixer_layer = nn.ModuleList(
            [mlp_layer(seq_length_s, hidden_size_c, mlp_dimension_ds, mlp_dimension_dc)
             for _ in range(self.layer_num)])
        # Linear classification head.
        self.linear_classifier_head = nn.Linear(hidden_size_c, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        :param x: tensor of shape [B, in_channels, img_size, img_size]
        :return: logits of shape [B, num_classes]
        """
        B, C, H, W = x.shape
        assert H == self.img_size and W == self.img_size, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
        # [B, C, H, W] -> conv -> [B, hidden, H/P, W/P]
        # flatten(2) -> [B, hidden, S]; transpose -> [B, S, hidden]
        x = self.proj(x).flatten(2).transpose(1, 2)
        for mixer_layer in self.mixer_layer:
            x = mixer_layer(x)
        x = self.ln(x)
        x = x.mean(dim=1)  # global average pooling over the token dimension
        x = self.linear_classifier_head(x)
        return x
# 参数初始化
def _init_mlp_mixer_weights(m):
"""
MLP Mixer weight initialization
:param m: module
"""
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
# 不同配置
def mlp_mixer_s_32(num_classes: int = 1000):
    """Mixer-S/32 configuration from Table 1 of the paper.

    Fix: the token-mixing hidden width D_S was mistyped as 3256; Table 1
    specifies D_S = 256 for the S (small) models (compare mlp_mixer_s_16).
    """
    model = mlp_mixer(num_classes=num_classes,
                      img_size=224,
                      in_channels=3,
                      layer_num=8,
                      patch_size=32,
                      hidden_size_c=512,
                      seq_length_s=49,
                      mlp_dimension_dc=2048,
                      mlp_dimension_ds=256)
    return model
def mlp_mixer_s_16(num_classes: int = 1000):
    """Mixer-S/16 configuration from Table 1 of the paper."""
    cfg = dict(
        img_size=224,
        in_channels=3,
        layer_num=8,
        patch_size=16,
        hidden_size_c=512,
        seq_length_s=196,
        mlp_dimension_dc=2048,
        mlp_dimension_ds=256,
    )
    return mlp_mixer(num_classes=num_classes, **cfg)
def mlp_mixer_b_32(num_classes: int = 1000):
    """Mixer-B/32 configuration from Table 1 of the paper."""
    cfg = dict(
        img_size=224,
        in_channels=3,
        layer_num=12,
        patch_size=32,
        hidden_size_c=768,
        seq_length_s=49,
        mlp_dimension_dc=3072,
        mlp_dimension_ds=384,
    )
    return mlp_mixer(num_classes=num_classes, **cfg)
def mlp_mixer_b_16(num_classes: int = 1000):
    """Mixer-B/16 configuration from Table 1 of the paper."""
    cfg = dict(
        img_size=224,
        in_channels=3,
        layer_num=12,
        patch_size=16,
        hidden_size_c=768,
        seq_length_s=196,
        mlp_dimension_dc=3072,
        mlp_dimension_ds=384,
    )
    return mlp_mixer(num_classes=num_classes, **cfg)
def mlp_mixer_l_32(num_classes: int = 1000):
    """Mixer-L/32 configuration from Table 1 of the paper."""
    cfg = dict(
        img_size=224,
        in_channels=3,
        layer_num=24,
        patch_size=32,
        hidden_size_c=1024,
        seq_length_s=49,
        mlp_dimension_dc=4096,
        mlp_dimension_ds=512,
    )
    return mlp_mixer(num_classes=num_classes, **cfg)
def mlp_mixer_l_16(num_classes: int = 1000):
    """Mixer-L/16 configuration from Table 1 of the paper."""
    cfg = dict(
        img_size=224,
        in_channels=3,
        layer_num=24,
        patch_size=16,
        hidden_size_c=1024,
        seq_length_s=196,
        mlp_dimension_dc=4096,
        mlp_dimension_ds=512,
    )
    return mlp_mixer(num_classes=num_classes, **cfg)
def mlp_mixer_h_14(num_classes=1000):
    """Mixer-H/14 configuration from Table 1 of the paper."""
    cfg = dict(
        img_size=224,
        in_channels=3,
        layer_num=32,
        patch_size=14,
        hidden_size_c=1280,
        seq_length_s=256,
        mlp_dimension_dc=5120,
        mlp_dimension_ds=640,
    )
    return mlp_mixer(num_classes=num_classes, **cfg)
# 测试用
if __name__ == '__main__':
    # Smoke test: build Mixer-S/16 and print the per-layer parameter report,
    # to be compared against Table 1 of the paper.
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = mlp_mixer_s_16(num_classes=1000).to(dev)
    summary(net, (3, 224, 224))
使用S/16 测试网络框架以及参数量结果如下,和paper中一致,模型搭建正确。
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 512, 14, 14] 393,728
LayerNorm-2 [-1, 196, 512] 1,024
Linear-3 [-1, 512, 256] 50,432
GELU-4 [-1, 512, 256] 0
Dropout-5 [-1, 512, 256] 0
Linear-6 [-1, 512, 196] 50,372
Dropout-7 [-1, 512, 196] 0
mlp_block-8 [-1, 512, 196] 0
LayerNorm-9 [-1, 196, 512] 1,024
Linear-10 [-1, 196, 2048] 1,050,624
GELU-11 [-1, 196, 2048] 0
Dropout-12 [-1, 196, 2048] 0
Linear-13 [-1, 196, 512] 1,049,088
Dropout-14 [-1, 196, 512] 0
mlp_block-15 [-1, 196, 512] 0
mlp_layer-16 [-1, 196, 512] 0
LayerNorm-17 [-1, 196, 512] 1,024
Linear-18 [-1, 512, 256] 50,432
GELU-19 [-1, 512, 256] 0
Dropout-20 [-1, 512, 256] 0
Linear-21 [-1, 512, 196] 50,372
Dropout-22 [-1, 512, 196] 0
mlp_block-23 [-1, 512, 196] 0
LayerNorm-24 [-1, 196, 512] 1,024
Linear-25 [-1, 196, 2048] 1,050,624
GELU-26 [-1, 196, 2048] 0
Dropout-27 [-1, 196, 2048] 0
Linear-28 [-1, 196, 512] 1,049,088
Dropout-29 [-1, 196, 512] 0
mlp_block-30 [-1, 196, 512] 0
mlp_layer-31 [-1, 196, 512] 0
LayerNorm-32 [-1, 196, 512] 1,024
Linear-33 [-1, 512, 256] 50,432
GELU-34 [-1, 512, 256] 0
Dropout-35 [-1, 512, 256] 0
Linear-36 [-1, 512, 196] 50,372
Dropout-37 [-1, 512, 196] 0
mlp_block-38 [-1, 512, 196] 0
LayerNorm-39 [-1, 196, 512] 1,024
Linear-40 [-1, 196, 2048] 1,050,624
GELU-41 [-1, 196, 2048] 0
Dropout-42 [-1, 196, 2048] 0
Linear-43 [-1, 196, 512] 1,049,088
Dropout-44 [-1, 196, 512] 0
mlp_block-45 [-1, 196, 512] 0
mlp_layer-46 [-1, 196, 512] 0
LayerNorm-47 [-1, 196, 512] 1,024
Linear-48 [-1, 512, 256] 50,432
GELU-49 [-1, 512, 256] 0
Dropout-50 [-1, 512, 256] 0
Linear-51 [-1, 512, 196] 50,372
Dropout-52 [-1, 512, 196] 0
mlp_block-53 [-1, 512, 196] 0
LayerNorm-54 [-1, 196, 512] 1,024
Linear-55 [-1, 196, 2048] 1,050,624
GELU-56 [-1, 196, 2048] 0
Dropout-57 [-1, 196, 2048] 0
Linear-58 [-1, 196, 512] 1,049,088
Dropout-59 [-1, 196, 512] 0
mlp_block-60 [-1, 196, 512] 0
mlp_layer-61 [-1, 196, 512] 0
LayerNorm-62 [-1, 196, 512] 1,024
Linear-63 [-1, 512, 256] 50,432
GELU-64 [-1, 512, 256] 0
Dropout-65 [-1, 512, 256] 0
Linear-66 [-1, 512, 196] 50,372
Dropout-67 [-1, 512, 196] 0
mlp_block-68 [-1, 512, 196] 0
LayerNorm-69 [-1, 196, 512] 1,024
Linear-70 [-1, 196, 2048] 1,050,624
GELU-71 [-1, 196, 2048] 0
Dropout-72 [-1, 196, 2048] 0
Linear-73 [-1, 196, 512] 1,049,088
Dropout-74 [-1, 196, 512] 0
mlp_block-75 [-1, 196, 512] 0
mlp_layer-76 [-1, 196, 512] 0
LayerNorm-77 [-1, 196, 512] 1,024
Linear-78 [-1, 512, 256] 50,432
GELU-79 [-1, 512, 256] 0
Dropout-80 [-1, 512, 256] 0
Linear-81 [-1, 512, 196] 50,372
Dropout-82 [-1, 512, 196] 0
mlp_block-83 [-1, 512, 196] 0
LayerNorm-84 [-1, 196, 512] 1,024
Linear-85 [-1, 196, 2048] 1,050,624
GELU-86 [-1, 196, 2048] 0
Dropout-87 [-1, 196, 2048] 0
Linear-88 [-1, 196, 512] 1,049,088
Dropout-89 [-1, 196, 512] 0
mlp_block-90 [-1, 196, 512] 0
mlp_layer-91 [-1, 196, 512] 0
LayerNorm-92 [-1, 196, 512] 1,024
Linear-93 [-1, 512, 256] 50,432
GELU-94 [-1, 512, 256] 0
Dropout-95 [-1, 512, 256] 0
Linear-96 [-1, 512, 196] 50,372
Dropout-97 [-1, 512, 196] 0
mlp_block-98 [-1, 512, 196] 0
LayerNorm-99 [-1, 196, 512] 1,024
Linear-100 [-1, 196, 2048] 1,050,624
GELU-101 [-1, 196, 2048] 0
Dropout-102 [-1, 196, 2048] 0
Linear-103 [-1, 196, 512] 1,049,088
Dropout-104 [-1, 196, 512] 0
mlp_block-105 [-1, 196, 512] 0
mlp_layer-106 [-1, 196, 512] 0
LayerNorm-107 [-1, 196, 512] 1,024
Linear-108 [-1, 512, 256] 50,432
GELU-109 [-1, 512, 256] 0
Dropout-110 [-1, 512, 256] 0
Linear-111 [-1, 512, 196] 50,372
Dropout-112 [-1, 512, 196] 0
mlp_block-113 [-1, 512, 196] 0
LayerNorm-114 [-1, 196, 512] 1,024
Linear-115 [-1, 196, 2048] 1,050,624
GELU-116 [-1, 196, 2048] 0
Dropout-117 [-1, 196, 2048] 0
Linear-118 [-1, 196, 512] 1,049,088
Dropout-119 [-1, 196, 512] 0
mlp_block-120 [-1, 196, 512] 0
mlp_layer-121 [-1, 196, 512] 0
LayerNorm-122 [-1, 196, 512] 1,024
Linear-123 [-1, 1000] 513,000
================================================================
Total params: 18,528,264
Trainable params: 18,528,264
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 154.16
Params size (MB): 70.68
Estimated Total Size (MB): 225.42
----------------------------------------------------------------
Process finished with exit code 0