A PyTorch implementation of MLP-Mixer ("MLP-Mixer: An all-MLP Architecture for Vision")

The Google team's recent paper "MLP-Mixer: An all-MLP Architecture for Vision" proposes a network built entirely from multi-layer perceptrons, with no convolutions (CNN) and no attention. The paper sets out to show that "neither of them is necessary".
Since the network is not complicated, I tried building the model myself.

1. The model specifications follow Table 1 of the paper ("Specifications of the Mixer architectures used in this paper"); the parameter names in the code correspond to that table.
2. The overall framework of the model is shown in the figure below; the implementation is split into three parts:
the MLP block, made of two fully connected layers,
the Mixer Layer, which combines token-mixing and channel-mixing,
the patch splitting and embedding, global average pooling, and the classification head.
(Figure: overall model structure)
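For reference, each Mixer layer computes the two residual mixing steps below (Eq. 1 in the paper), where the input X is an S x C table of patch embeddings, \sigma is GELU, token-mixing acts on the columns of X and channel-mixing on its rows:

U_{*,i} = X_{*,i} + W_2 \, \sigma\big(W_1 \, \mathrm{LayerNorm}(X)_{*,i}\big), \quad i = 1, \dots, C
Y_{j,*} = U_{j,*} + W_4 \, \sigma\big(W_3 \, \mathrm{LayerNorm}(U)_{j,*}\big), \quad j = 1, \dots, S
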
The code is listed below; it is fairly clear and straightforward.

import torch
import torch.nn as nn
from torchsummary import summary


# MLP block: two fully connected layers with a GELU activation in between
# Both mixings keep the input/output dimension unchanged; only the hidden width (mlp_dim) of the middle layer differs
class mlp_block(nn.Module):
    def __init__(self, in_channels, mlp_dim, drop_ratio=0.):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_channels, mlp_dim),
            nn.GELU(),
            nn.Dropout(drop_ratio),
            nn.Linear(mlp_dim, in_channels),
            nn.Dropout(drop_ratio)
        )

    def forward(self, x):
        x = self.block(x)
        return x


class mlp_layer(nn.Module):
    def __init__(self, seq_length_s, hidden_size_c, mlp_dimension_dc, mlp_dimension_ds):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size_c)
        self.ln2 = nn.LayerNorm(hidden_size_c)
        # Note: the two blocks act on the columns and rows of the S x C input respectively, so their in_channels differ;
        # the token-mixing MLP has hidden width D_S and the channel-mixing MLP has hidden width D_C (Table 1)
        self.token_mixing = mlp_block(in_channels=seq_length_s, mlp_dim=mlp_dimension_ds)
        self.channel_mixing = mlp_block(in_channels=hidden_size_c, mlp_dim=mlp_dimension_dc)

    def forward(self, x):
        x1 = self.ln1(x)
        x2 = x1.transpose(1, 2)  # transpose to [B, C, S] so the token-mixing MLP acts across patches
        x3 = self.token_mixing(x2)
        x4 = x3.transpose(1, 2)

        y1 = x + x4  # skip-connection
        y2 = self.ln2(y1)
        y3 = self.channel_mixing(y2)
        y = y1 + y3

        return y

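# Shape trace through one Mixer layer for the S/16 configuration (assumed batch size B):
#   x [B, 196, 512] -> ln1 -> transpose [B, 512, 196] -> token_mixing (196 -> 256 -> 196) -> transpose back, skip-add [B, 196, 512]
#   -> ln2 -> channel_mixing (512 -> 2048 -> 512) -> skip-add [B, 196, 512]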

# Parameters configured according to Table 1 of the paper
class mlp_mixer(nn.Module):
    def __init__(self,
                 num_classes=1000,
                 img_size=224,
                 in_channels=3,
                 layer_num=12,
                 patch_size=32,
                 hidden_size_c=768,
                 seq_length_s=49,
                 mlp_dimension_dc=3072,
                 mlp_dimension_ds=384,
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.img_size = img_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.layer_num = layer_num
        self.hidden_size_c = hidden_size_c
        self.seq_length_s = seq_length_s
        self.mlp_dimension_dc = mlp_dimension_dc
        self.mlp_dimension_ds = mlp_dimension_ds

        self.ln = nn.LayerNorm(self.hidden_size_c)

        # Split the image into patches and embed them, implemented with a single strided convolution
        self.proj = nn.Conv2d(self.in_channels, self.hidden_size_c, kernel_size=self.patch_size, stride=self.patch_size)

        # Stack multiple Mixer layers
        self.mixer_layer = nn.ModuleList([])
        for _ in range(self.layer_num):
            self.mixer_layer.append(mlp_layer(seq_length_s, hidden_size_c, mlp_dimension_dc, mlp_dimension_ds))

        # Final fully connected classification head
        self.linear_classifier_head = nn.Linear(hidden_size_c, num_classes)

    # Forward pass
    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size and W == self.img_size, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."

        # flatten: [B, C, H, W] -> [B, C, HW]  (flatten from dim 2, i.e. the two spatial dims)
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        for mixer_layer in self.mixer_layer:
            x = mixer_layer(x)
        x = self.ln(x)
        x = x.mean(dim=1)  # Global average pooling
        x = self.linear_classifier_head(x)
        return x

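# Shape summary for the default B/32 configuration (assumed batch size B):
#   input [B, 3, 224, 224] -> proj [B, 768, 7, 7] -> flatten + transpose [B, 49, 768]
#   -> 12 Mixer layers [B, 49, 768] -> mean over patches [B, 768] -> linear head [B, 1000]
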
# Weight initialization
def _init_mlp_mixer_weights(m):
    """
    MLP Mixer weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


# Model configurations following Table 1
def mlp_mixer_s_32(num_classes: int = 1000):
    model = mlp_mixer(num_classes=num_classes,
                      img_size=224,
                      in_channels=3,
                      layer_num=8,
                      patch_size=32,
                      hidden_size_c=512,
                      seq_length_s=49,
                      mlp_dimension_dc=2048,
                      mlp_dimension_ds=256)
    return model


def mlp_mixer_s_16(num_classes: int = 1000):
    model = mlp_mixer(num_classes=num_classes,
                      img_size=224,
                      in_channels=3,
                      layer_num=8,
                      patch_size=16,
                      hidden_size_c=512,
                      seq_length_s=196,
                      mlp_dimension_dc=2048,
                      mlp_dimension_ds=256)
    return model


def mlp_mixer_b_32(num_classes: int = 1000):
    model = mlp_mixer(num_classes=num_classes,
                      img_size=224,
                      in_channels=3,
                      layer_num=12,
                      patch_size=32,
                      hidden_size_c=768,
                      seq_length_s=49,
                      mlp_dimension_dc=3072,
                      mlp_dimension_ds=384)
    return model


def mlp_mixer_b_16(num_classes: int = 1000):
    model = mlp_mixer(num_classes=num_classes,
                      img_size=224,
                      in_channels=3,
                      layer_num=12,
                      patch_size=16,
                      hidden_size_c=768,
                      seq_length_s=196,
                      mlp_dimension_dc=3072,
                      mlp_dimension_ds=384)
    return model


def mlp_mixer_l_32(num_classes: int = 1000):
    model = mlp_mixer(num_classes=num_classes,
                      img_size=224,
                      in_channels=3,
                      layer_num=24,
                      patch_size=32,
                      hidden_size_c=1024,
                      seq_length_s=49,
                      mlp_dimension_dc=4096,
                      mlp_dimension_ds=512)
    return model


def mlp_mixer_l_16(num_classes: int = 1000):
    model = mlp_mixer(num_classes=num_classes,
                      img_size=224,
                      in_channels=3,
                      layer_num=24,
                      patch_size=16,
                      hidden_size_c=1024,
                      seq_length_s=196,
                      mlp_dimension_dc=4096,
                      mlp_dimension_ds=512)
    return model


def mlp_mixer_h_14(num_classes=1000):
    model = mlp_mixer(num_classes=num_classes,
                      img_size=224,
                      in_channels=3,
                      layer_num=32,
                      patch_size=14,
                      hidden_size_c=1280,
                      seq_length_s=256,
                      mlp_dimension_dc=5120,
                      mlp_dimension_ds=640)
    return model


# Quick test
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = mlp_mixer_s_16(num_classes=1000).to(device)
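    model.apply(_init_mlp_mixer_weights)  # apply the custom weight initialization defined above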
    summary(model, (3, 224, 224))

Running the S/16 configuration, the network structure and parameter count are shown below; the total matches the paper, so the model is built correctly.

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 512, 14, 14]         393,728
         LayerNorm-2             [-1, 196, 512]           1,024
            Linear-3             [-1, 512, 256]          50,432
              GELU-4             [-1, 512, 256]               0
           Dropout-5             [-1, 512, 256]               0
            Linear-6             [-1, 512, 196]          50,372
           Dropout-7             [-1, 512, 196]               0
         mlp_block-8             [-1, 512, 196]               0
         LayerNorm-9             [-1, 196, 512]           1,024
           Linear-10            [-1, 196, 2048]       1,050,624
             GELU-11            [-1, 196, 2048]               0
          Dropout-12            [-1, 196, 2048]               0
           Linear-13             [-1, 196, 512]       1,049,088
          Dropout-14             [-1, 196, 512]               0
        mlp_block-15             [-1, 196, 512]               0
        mlp_layer-16             [-1, 196, 512]               0
        LayerNorm-17             [-1, 196, 512]           1,024
           Linear-18             [-1, 512, 256]          50,432
             GELU-19             [-1, 512, 256]               0
          Dropout-20             [-1, 512, 256]               0
           Linear-21             [-1, 512, 196]          50,372
          Dropout-22             [-1, 512, 196]               0
        mlp_block-23             [-1, 512, 196]               0
        LayerNorm-24             [-1, 196, 512]           1,024
           Linear-25            [-1, 196, 2048]       1,050,624
             GELU-26            [-1, 196, 2048]               0
          Dropout-27            [-1, 196, 2048]               0
           Linear-28             [-1, 196, 512]       1,049,088
          Dropout-29             [-1, 196, 512]               0
        mlp_block-30             [-1, 196, 512]               0
        mlp_layer-31             [-1, 196, 512]               0
        LayerNorm-32             [-1, 196, 512]           1,024
           Linear-33             [-1, 512, 256]          50,432
             GELU-34             [-1, 512, 256]               0
          Dropout-35             [-1, 512, 256]               0
           Linear-36             [-1, 512, 196]          50,372
          Dropout-37             [-1, 512, 196]               0
        mlp_block-38             [-1, 512, 196]               0
        LayerNorm-39             [-1, 196, 512]           1,024
           Linear-40            [-1, 196, 2048]       1,050,624
             GELU-41            [-1, 196, 2048]               0
          Dropout-42            [-1, 196, 2048]               0
           Linear-43             [-1, 196, 512]       1,049,088
          Dropout-44             [-1, 196, 512]               0
        mlp_block-45             [-1, 196, 512]               0
        mlp_layer-46             [-1, 196, 512]               0
        LayerNorm-47             [-1, 196, 512]           1,024
           Linear-48             [-1, 512, 256]          50,432
             GELU-49             [-1, 512, 256]               0
          Dropout-50             [-1, 512, 256]               0
           Linear-51             [-1, 512, 196]          50,372
          Dropout-52             [-1, 512, 196]               0
        mlp_block-53             [-1, 512, 196]               0
        LayerNorm-54             [-1, 196, 512]           1,024
           Linear-55            [-1, 196, 2048]       1,050,624
             GELU-56            [-1, 196, 2048]               0
          Dropout-57            [-1, 196, 2048]               0
           Linear-58             [-1, 196, 512]       1,049,088
          Dropout-59             [-1, 196, 512]               0
        mlp_block-60             [-1, 196, 512]               0
        mlp_layer-61             [-1, 196, 512]               0
        LayerNorm-62             [-1, 196, 512]           1,024
           Linear-63             [-1, 512, 256]          50,432
             GELU-64             [-1, 512, 256]               0
          Dropout-65             [-1, 512, 256]               0
           Linear-66             [-1, 512, 196]          50,372
          Dropout-67             [-1, 512, 196]               0
        mlp_block-68             [-1, 512, 196]               0
        LayerNorm-69             [-1, 196, 512]           1,024
           Linear-70            [-1, 196, 2048]       1,050,624
             GELU-71            [-1, 196, 2048]               0
          Dropout-72            [-1, 196, 2048]               0
           Linear-73             [-1, 196, 512]       1,049,088
          Dropout-74             [-1, 196, 512]               0
        mlp_block-75             [-1, 196, 512]               0
        mlp_layer-76             [-1, 196, 512]               0
        LayerNorm-77             [-1, 196, 512]           1,024
           Linear-78             [-1, 512, 256]          50,432
             GELU-79             [-1, 512, 256]               0
          Dropout-80             [-1, 512, 256]               0
           Linear-81             [-1, 512, 196]          50,372
          Dropout-82             [-1, 512, 196]               0
        mlp_block-83             [-1, 512, 196]               0
        LayerNorm-84             [-1, 196, 512]           1,024
           Linear-85            [-1, 196, 2048]       1,050,624
             GELU-86            [-1, 196, 2048]               0
          Dropout-87            [-1, 196, 2048]               0
           Linear-88             [-1, 196, 512]       1,049,088
          Dropout-89             [-1, 196, 512]               0
        mlp_block-90             [-1, 196, 512]               0
        mlp_layer-91             [-1, 196, 512]               0
        LayerNorm-92             [-1, 196, 512]           1,024
           Linear-93             [-1, 512, 256]          50,432
             GELU-94             [-1, 512, 256]               0
          Dropout-95             [-1, 512, 256]               0
           Linear-96             [-1, 512, 196]          50,372
          Dropout-97             [-1, 512, 196]               0
        mlp_block-98             [-1, 512, 196]               0
        LayerNorm-99             [-1, 196, 512]           1,024
          Linear-100            [-1, 196, 2048]       1,050,624
            GELU-101            [-1, 196, 2048]               0
         Dropout-102            [-1, 196, 2048]               0
          Linear-103             [-1, 196, 512]       1,049,088
         Dropout-104             [-1, 196, 512]               0
       mlp_block-105             [-1, 196, 512]               0
       mlp_layer-106             [-1, 196, 512]               0
       LayerNorm-107             [-1, 196, 512]           1,024
          Linear-108             [-1, 512, 256]          50,432
            GELU-109             [-1, 512, 256]               0
         Dropout-110             [-1, 512, 256]               0
          Linear-111             [-1, 512, 196]          50,372
         Dropout-112             [-1, 512, 196]               0
       mlp_block-113             [-1, 512, 196]               0
       LayerNorm-114             [-1, 196, 512]           1,024
          Linear-115            [-1, 196, 2048]       1,050,624
            GELU-116            [-1, 196, 2048]               0
         Dropout-117            [-1, 196, 2048]               0
          Linear-118             [-1, 196, 512]       1,049,088
         Dropout-119             [-1, 196, 512]               0
       mlp_block-120             [-1, 196, 512]               0
       mlp_layer-121             [-1, 196, 512]               0
       LayerNorm-122             [-1, 196, 512]           1,024
          Linear-123                 [-1, 1000]         513,000
================================================================
Total params: 18,528,264
Trainable params: 18,528,264
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 154.16
Params size (MB): 70.68
Estimated Total Size (MB): 225.42
----------------------------------------------------------------

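As a quick cross-check that does not rely on torchsummary, the number of trainable parameters can also be summed directly from the model (a minimal sketch, assuming the classes defined above are in scope):

model = mlp_mixer_s_16(num_classes=1000)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params:,}")  # should print 18,528,264 for S/16, matching the total above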