Squeeze Net代码与解析

本文介绍了一种名为SqueezeNet的高效卷积神经网络,重点讲解了FireBlock结构,包括如何通过1x1和3x3卷积的结合、减小通道数以及推迟下采样来提升模型效率。该网络在保持高精度的同时,减少了参数和计算量,适合资源有限的设备上应用。
摘要由CSDN通过智能技术生成

参考链接:https://zhuanlan.zhihu.com/p/49465950

主要的三个策略:

  1. 部分使用1x1卷积替换3x3卷积
  2. 减少3x3卷积的输入通道数量。
  3. 将网络下采样的时机推迟到网络后面,因为在其他情况不变下,尺寸大的特征图具有更高的分类准确度。

主要Block为Fire Block:

在这里插入图片描述
在这里插入图片描述

网络结构:

在这里插入图片描述

代码:

import torch
import torch.nn as nn

class fire(nn.Module):
    def __init__(self,in_channel, out_channel):
        super(fire, self).__init__()
        self.conv1 = nn.Conv2d(in_channel,out_channel//8,kernel_size=1)
        self.conv2_1 = nn.Conv2d(out_channel//8,out_channel//2,kernel_size=1)
        self.conv2_2 = nn.Conv2d(out_channel//8,out_channel//2,kernel_size=3,padding= 3//2)
        self.BN1 = nn.BatchNorm2d(out_channel//4)

        self.ReLU = nn.ReLU()

    def forward(self,x):
        out = self.ReLU(self.BN1(self.conv1(x)))
        out1 = self.conv2_1(out)
        out2 = self.conv2_2(out)
        out  = self.ReLU(torch.cat([out1,out2],1))
        return out

class SQUEEZE(nn.Module):
    def __init__(self,in_channel, classses):
        super(SQUEEZE, self).__init__()
        channels = [96,128,128,256,256,384,384,512,512]
        self.conv1 = nn.Conv2d(in_channel,channels[0],7,2,padding=7//2)
        self.pool1 = nn.MaxPool2d(kernel_size=3,stride=2)
        self.BN1 = nn.BatchNorm2d(channels[0])
        self.block = fire
        self.block1 = nn.ModuleList([])
        for i in range(7):
            self.block1.append(self.block(in_channel = channels[i],out_channel = channels[i+1]))
            if i in [3,6]:
                self.block1.append(nn.MaxPool2d(kernel_size=3,stride=2))
        self.block1.append(self.block(channels[-2],channels[-1]))
        self.conv10 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(channels[-1],classses,kernel_size=1,stride=1),
            nn.ReLU())

        self.pool2 = nn.MaxPool2d(kernel_size=13)

    def forward(self,x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.BN1(x)
        for block in self.block1:
            x = block(x)
        x = self.conv10(x)
        out = self.pool2(x)
        return out

if __name__ == '__main__':
    input = torch.empty(1,3,224,224)
    m = SQUEEZE(3,10)
    out = m(input)
    print(out)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值