MobileNet V3代码

（原文此处有 4 张插图，提取时已丢失，仅剩占位文字“在这里插入图片描述”。）

代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class h_swish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``.

    A ReLU6-based (hard) approximation of swish with no learnable
    parameters.
    """

    def forward(self, x):
        gate = F.relu6(x + 3)
        return x * gate / 6

class swish(nn.Module):
    """Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``.

    Uses ``torch.sigmoid`` instead of the deprecated
    ``torch.nn.functional.sigmoid`` alias; the numerical result is the same.
    """

    def forward(self, x):
        return x * torch.sigmoid(x)

class h_sigmoid(nn.Module):
    """Hard sigmoid: ``relu6(x + 3) / 6``, a piecewise-linear map into [0, 1].

    Has no learnable parameters.
    """

    def forward(self, x):
        shifted = x + 3
        return F.relu6(shifted) / 6

def _make_divisor(ch, divisor, min_ch = None):
    if not min_ch:
        min_ch = divisor
    new_ch = max(min_ch,int(ch+divisor/2)//divisor*divisor)
    if new_ch < 0.9*ch:
        new_ch += divisor
    return new_ch

class SE_module(nn.Module):
    """Squeeze-and-excitation channel attention.

    The input is global-average-pooled to a 1x1 map per channel. That map
    passes through a 1x1-conv bottleneck (ReLU) and an ``h_sigmoid`` gate.
    The input is then rescaled channel-wise by the resulting weights.

    Args:
        inchannel: number of channels of the input (and of the output).
        reduction: bottleneck reduction factor. The squeezed width is
            ``max(1, inchannel // reduction)``. The default of 4 reproduces
            the previous behaviour whenever ``inchannel >= 4``.
    """

    def __init__(self, inchannel, reduction=4):
        super(SE_module, self).__init__()
        # Clamp to at least one channel. For inchannel < reduction, plain
        # floor division would create a zero-width bottleneck, which makes
        # the gate ignore its input.
        squeeze = max(1, inchannel // reduction)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inchannel, squeeze, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeeze, inchannel, 1),
            h_sigmoid()
        )

    def forward(self, x):
        # The gate has 1x1 spatial size and broadcasts over x's spatial dims.
        mul = self.se(x)
        return x * mul

class bneck(nn.Module):
    """MobileNetV3 inverted-residual bottleneck ("bneck") block.

    Layout, in order:
      * 1x1 expansion -> BN -> act (only when ``inchannel != hidden_channel``)
      * depthwise kxk conv carrying the block's stride -> BN -> act
      * optional squeeze-and-excitation
      * 1x1 linear projection -> BN
      * identity shortcut added when input and output shapes match

    Args:
        inchannel: input channel count.
        outchannel: output channel count.
        hidden_channel: expanded channel count seen by the depthwise conv.
        nonlinear: activation class (e.g. ``nn.ReLU`` or ``h_swish``). A new
            instance is created for every use.
        stride: stride of the depthwise conv.
        SE: if true, insert a squeeze-and-excitation module after the
            depthwise conv.
        kernel_size: depthwise kernel size. Padding is ``kernel_size // 2``,
            so an odd kernel keeps the spatial size at stride 1. Defaults
            to 3, the value the block previously hard-coded.
    """

    def __init__(self, inchannel, outchannel, hidden_channel, nonlinear, stride, SE=False, kernel_size=3):
        super(bneck, self).__init__()
        # The residual add is only shape-valid without downsampling and when
        # the channel count is unchanged.
        self.shortcut = stride == 1 and inchannel == outchannel

        layers = []
        if inchannel != hidden_channel:
            layers.extend([
                nn.Conv2d(inchannel, hidden_channel, 1),
                nn.BatchNorm2d(hidden_channel),
                nonlinear()
            ])
        layers.extend([
            # Depthwise conv: this is where the block's stride is applied.
            nn.Conv2d(hidden_channel, hidden_channel, kernel_size, stride,
                      kernel_size // 2, groups=hidden_channel),
            nn.BatchNorm2d(hidden_channel),
            nonlinear()
        ])
        self.conv1 = nn.Sequential(*layers)
        # Honour the SE flag. Identity keeps forward() branch-free.
        self.se = SE_module(hidden_channel) if SE else nn.Identity()
        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_channel, outchannel, 1, 1),
            nn.BatchNorm2d(outchannel)
        )

    def forward(self, x):
        out = self.conv2(self.se(self.conv1(x)))
        if self.shortcut:
            out = out + x
        return out


class MobileNet_V3(nn.Module):
    """MobileNetV3 network assembled from a per-block ``setting`` table.

    Args:
        setting: rows of ``[in_ch, kernel_size, exp_ch, out_ch, use_se, nl,
            stride]``.
            * ``in_ch`` is ignored; channels are chained from the previous
              block.
            * ``nl == 'RE'`` selects ReLU; any other value selects h-swish.
        inchannel: channel count of the input image.
        classes: number of output classes.
        alpha: width multiplier applied to the stem width (16) and to every
            ``exp_ch``/``out_ch`` in ``setting``.
        round_nearest: divisor those scaled widths are rounded to.

    ``forward`` returns logits shaped (N, classes, 1, 1). They are not
    flattened.
    """

    def __init__(self, setting, inchannel, classes, alpha = 0.2, round_nearest = 8):
        super(MobileNet_V3, self).__init__()
        input_channel = _make_divisor(16*alpha, round_nearest)

        self.HS = h_swish
        self.RE = nn.ReLU

        # Stem: 3x3 stride-2 conv.
        self.conv1 = nn.Sequential(nn.Conv2d(inchannel, input_channel, 3, 2, 1),
                                   nn.BatchNorm2d(input_channel),
                                   self.HS()
                                   )

        self.block = bneck
        self.blocks = nn.ModuleList([])
        self.nonlin = self.HS
        for _, kernel_size, hidden, out_channels, SE, nonlinear, stride in setting:
            self.nonlin = self.RE if nonlinear == 'RE' else self.HS
            self.hidden = _make_divisor(hidden*alpha, round_nearest)
            out_channels = _make_divisor(out_channels*alpha, round_nearest)
            # Forward the table's kernel size; it used to be discarded.
            self.blocks.append(self.block(input_channel, out_channels, self.hidden,
                                          self.nonlin, stride, SE,
                                          kernel_size=kernel_size))
            input_channel = out_channels

        # self.hidden still holds the last block's (scaled) expansion width.
        # It is reused as the width of the final 1x1 expansion conv.
        self.conv2 = nn.Sequential(
            nn.Conv2d(input_channel, self.hidden, 1, 1),
            nn.BatchNorm2d(self.hidden),
            SE_module(self.hidden),
            self.HS()
        )

        self.pool = nn.AdaptiveAvgPool2d((1))
        # Classifier head implemented with 1x1 convs on the pooled 1x1 map.
        self.conv3 = nn.Sequential(
            nn.Conv2d(self.hidden, 1024, 1, 1),
            self.HS(),
            nn.Dropout(0.2),
            nn.Conv2d(1024, classes, 1, 1)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode = 'fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.conv1(x)
        for block in self.blocks:
            x = block(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.conv3(x)
        return x



def MobileNet_V3_large(inchannel, classes, alpha=0.2, round_nearest=8):
    """Build the MobileNetV3-Large configuration (15 bneck rows).

    Args:
        inchannel: channel count of the input image.
        classes: number of output classes.
        alpha: width multiplier forwarded to ``MobileNet_V3``. The default
            of 0.2 matches ``MobileNet_V3``'s own default.
        round_nearest: channel rounding divisor forwarded to
            ``MobileNet_V3``.
    """
    # Columns: in_ch, kernel, exp_ch, out_ch, use_SE, nonlinearity, stride
    setting = [
        [16, 3, 16, 16, False, 'RE', 1],
        [16, 3, 64, 24, False, 'RE', 2],
        [24, 3, 72, 24, False, 'RE', 1],
        [24, 5, 72, 40, True, 'RE', 2],
        [40, 5, 120, 40, True, 'RE', 1],
        [40, 5, 120, 40, True, 'RE', 1],
        [40, 3, 240, 80, False, 'HS', 2],
        [80, 3, 200, 80, False, 'HS', 1],
        [80, 3, 184, 80, False, 'HS', 1],
        [80, 3, 184, 80, False, 'HS', 1],
        [80, 3, 480, 112, True, 'HS', 1],
        [112, 3, 672, 112, True, 'HS', 1],
        [112, 5, 672, 160, True, 'HS', 2],
        [160, 5, 960, 160, True, 'HS', 1],
        [160, 5, 960, 160, True, 'HS', 1],
    ]
    return MobileNet_V3(setting, inchannel, classes, alpha, round_nearest)

def MobileNet_V3_small(inchannel, classes, alpha=0.2, round_nearest=8):
    """Build the MobileNetV3-Small configuration (11 bneck rows).

    Args:
        inchannel: channel count of the input image.
        classes: number of output classes.
        alpha: width multiplier forwarded to ``MobileNet_V3``. The default
            of 0.2 matches ``MobileNet_V3``'s own default.
        round_nearest: channel rounding divisor forwarded to
            ``MobileNet_V3``.
    """
    # Columns: in_ch, kernel, exp_ch, out_ch, use_SE, nonlinearity, stride
    setting = [
        [16, 3, 16, 16, True, 'RE', 2],
        [16, 3, 72, 24, False, 'RE', 2],
        [24, 3, 88, 24, False, 'RE', 1],
        [24, 5, 96, 40, True, 'HS', 2],
        [40, 5, 240, 40, True, 'HS', 1],
        [40, 5, 240, 40, True, 'HS', 1],
        [40, 5, 120, 48, True, 'HS', 1],
        [48, 5, 144, 48, True, 'HS', 1],
        [48, 5, 288, 96, True, 'HS', 2],
        [96, 5, 576, 96, True, 'HS', 1],
        [96, 5, 576, 96, True, 'HS', 1],
    ]
    return MobileNet_V3(setting, inchannel, classes, alpha, round_nearest)

if __name__ == '__main__':
    # Smoke test. Use random data rather than torch.empty, whose contents
    # are uninitialised and may contain NaN/inf. Run in inference mode
    # without autograd.
    dummy = torch.randn(1, 3, 224, 224)
    m = MobileNet_V3_small(3, 10)
    m.eval()
    with torch.no_grad():
        out = m(dummy)
    print(out.shape)
    print(out)





  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值