Apple 公司提出的 MobileViT 模型：我根据论文自己复现了该模型。

import torch
import torch.nn as nn


def conv3x3(in_ch: int, out_ch: int, group: int = 1, stride: int = 1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        group: Passed to ``nn.Conv2d`` as ``groups``.
        stride: Convolution stride.

    Returns:
        A freshly constructed ``nn.Conv2d`` (bias enabled, PyTorch default).
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, groups=group)
    return nn.Conv2d(in_ch, out_ch, **conv_kwargs)

def conv1x1(in_ch: int, out_ch: int, group: int = 1):
    """Build a point-wise (1x1) ``nn.Conv2d`` with no padding and stride 1.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        group: Passed to ``nn.Conv2d`` as ``groups``.

    Returns:
        A freshly constructed ``nn.Conv2d`` (bias enabled, PyTorch default).
    """
    layer = nn.Conv2d(in_ch, out_ch, 1, groups=group)
    return layer


class MobileNetV2_Block(nn.Module):
    """MobileNetV2 inverted-residual block (the "MV2" block of MobileViT).

    Layout: 1x1 expansion conv -> BN -> ReLU6 -> 3x3 depthwise conv (carries
    the stride) -> BN -> ReLU6 -> 1x1 linear projection -> BN. An identity
    shortcut is added when input and output shapes match (stride 1 and
    ``in_ch == out_ch``).

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        stride: Stride of the depthwise conv; 2 halves the spatial size.
        expansion: Width multiplier of the hidden layer (6 in MobileNetV2).
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1, expansion: int = 6):
        super(MobileNetV2_Block, self).__init__()
        hidden = in_ch * expansion
        # Every conv feeds straight into BatchNorm, so a conv bias would be redundant.
        self.conv1 = nn.Conv2d(in_ch, hidden, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        # groups == channels makes this the depthwise conv of the inverted residual.
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride,
                               padding=1, groups=hidden, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.conv3 = nn.Conv2d(hidden, out_ch, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU6(inplace=True)
        self.use_residual = stride == 1 and in_ch == out_ch

    def forward(self, x):
        out = self.act(self.bn1(self.conv1(x)))
        out = self.act(self.bn2(self.conv2(out)))
        # No activation after the projection: the bottleneck is kept linear.
        out = self.bn3(self.conv3(out))
        return x + out if self.use_residual else out

class MobileVit_Block(nn.Module):
    """MobileViT block (arXiv:2110.02178, Fig. 1b).

    Pipeline: 3x3 conv -> 1x1 conv to ``d_model`` -> unfold into patches ->
    transformer encoder -> fold back -> 1x1 conv to ``out_ch`` -> concatenate
    with the block input along channels -> 3x3 fusion conv to ``out_ch``.
    The spatial size is preserved.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels of the output feature map.
        d_model: Token embedding size of the transformer.
        nhead: Number of attention heads; must divide ``d_model``.
        num_encoder_layers: Number of transformer encoder layers (``L`` in the paper).
        num_decoder_layers: Accepted for backward compatibility and ignored;
            the block uses a transformer encoder only.
        dim_feedforward: Hidden size of each encoder layer's feed-forward network.
        patch_size: Side of the square patches the map is unfolded into. The
            transformer attends across patches for each pixel position inside
            a patch. H and W must be divisible by it. The default of 1 makes
            every pixel a token of a single sequence.

    Raises:
        ValueError: In ``forward``, if H or W is not divisible by ``patch_size``.
    """

    def __init__(self, in_ch: int, out_ch: int, d_model: int, nhead: int = 2,
                 num_encoder_layers: int = 6, num_decoder_layers: int = 6,
                 dim_feedforward: int = 2048, patch_size: int = 1):
        super(MobileVit_Block, self).__init__()
        self.d_model = d_model
        self.patch_size = patch_size
        # Local representation, then point-wise projection into token space.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, d_model, kernel_size=1)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer,
                                                 num_layers=num_encoder_layers)
        # Project tokens back to feature channels.
        self.conv3 = nn.Conv2d(d_model, out_ch, kernel_size=1)
        # Fusion over [block input, global features] concatenated on channels.
        self.conv4 = nn.Conv2d(in_ch + out_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        b, _, h, w = x.shape
        p = self.patch_size
        if h % p or w % p:
            raise ValueError(
                f"feature map {h}x{w} is not divisible by patch_size={p}")
        nh, nw = h // p, w // p
        d = self.d_model

        y = self.conv2(self.conv1(x))
        # Unfold (B, d, H, W) -> (B*p*p, nh*nw, d): one sequence of patches per
        # pixel position inside a patch. reshape copies when permute made it non-contiguous.
        y = y.reshape(b, d, nh, p, nw, p).permute(0, 3, 5, 2, 4, 1)
        y = y.reshape(b * p * p, nh * nw, d)
        y = self.transformer(y)
        # Fold back: exact inverse of the unfold above.
        y = y.reshape(b, p, p, nh, nw, d).permute(0, 5, 3, 1, 4, 2)
        y = y.reshape(b, d, h, w)
        y = self.conv3(y)
        return self.conv4(torch.cat([x, y], dim=1))

class Mobile_Vit(nn.Module):
    """MobileViT backbone following the layer layout of arXiv:2110.02178 (Fig. 1b).

    stem 3x3 conv /2 -> MV2 -> MV2 /2 -> MV2 x2 -> MV2 /2 -> MobileViT (L=2)
    -> MV2 /2 -> MobileViT (L=4) -> MV2 /2 -> MobileViT (L=3) -> 1x1 conv
    -> global average pool -> flatten.

    Every stage keeps ``in_ch`` channels. Only the final 1x1 conv changes the
    width, so the output has shape ``(batch, out_ch)``.

    Args:
        cnn_block: Called as ``cnn_block(in_ch, out_ch, stride)``,
            e.g. ``MobileNetV2_Block``.
        trans_block: Called as ``trans_block(in_ch, out_ch, d_model,
            num_encoder_layers=L)``, e.g. ``MobileVit_Block``.
        in_ch: Channels of the input image and of every intermediate stage.
        out_ch: Channels of the output vector.
        d_model: Transformer embedding size used by every transformer stage.
    """

    def __init__(self, cnn_block, trans_block, in_ch, out_ch, d_model=512):
        super(Mobile_Vit, self).__init__()

        self.conv1 = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=2, padding=1)
        self.mv1 = self._make_cnn_layer(cnn_block, in_ch, in_ch)
        self.mv2 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mv3 = self._make_cnn_layer(cnn_block, in_ch, in_ch, 2)
        self.mv4 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt1 = self._make_trans_layer(trans_block, in_ch, in_ch, d_model, 2)
        self.mv5 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt2 = self._make_trans_layer(trans_block, in_ch, in_ch, d_model, 4)
        self.mv6 = self._make_cnn_layer(cnn_block, in_ch, in_ch, stride=2)
        self.mt3 = self._make_trans_layer(trans_block, in_ch, in_ch, d_model, 3)
        self.conv2 = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.flat = nn.Flatten()

    def _make_cnn_layer(self, block, in_ch, out_ch, blocks=1, stride=1):
        """Stack ``blocks`` CNN blocks. Only the first one changes stride/width."""
        layers = []
        for i in range(blocks):
            first = i == 0
            layers.append(block(in_ch if first else out_ch, out_ch,
                                stride if first else 1))
        return nn.Sequential(*layers)

    def _make_trans_layer(self, block, in_ch, out_ch, d_model, depth):
        """Build one transformer-stage block with ``depth`` encoder layers.

        The paper's "MobileViT block (L=2/4/3)" is the number of transformer
        layers inside a single block, not a stack of blocks.
        """
        return block(in_ch, out_ch, d_model, num_encoder_layers=depth)

    def forward(self, x):
        x = self.conv1(x)
        x = self.mv1(x)
        x = self.mv2(x)
        x = self.mv3(x)
        x = self.mv4(x)
        x = self.mt1(x)
        x = self.mv5(x)
        x = self.mt2(x)
        x = self.mv6(x)
        x = self.mt3(x)
        x = self.conv2(x)
        x = self.flat(self.gap(x))
        return x

欢迎大家批评指正。论文链接：https://arxiv.org/pdf/2110.02178.pdf

 

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值