参考链接:https://zhuanlan.zhihu.com/p/49465950
SqueezeNet 的三个主要设计策略:
- 用 1x1 卷积替换部分 3x3 卷积。
- 减少3x3卷积的输入通道数量。
- 将下采样推迟到网络的较深层,因为在其他条件不变的情况下,尺寸较大的特征图能带来更高的分类准确率。
核心模块为 Fire Module(先用 1x1 卷积的 squeeze 层压缩通道,再用并行的 1x1 与 3x3 卷积组成的 expand 层扩展通道):
网络结构:
代码:
import torch
import torch.nn as nn
class fire(nn.Module):
    """SqueezeNet Fire module: 1x1 "squeeze" conv, then parallel 1x1 / 3x3 "expand" convs.

    The squeeze conv reduces the input to ``out_channel // 8`` channels (at least 1),
    followed by BatchNorm and ReLU. The two expand branches are concatenated along
    the channel axis and passed through ReLU. The output has exactly ``out_channel``
    channels and the same spatial size as the input (the 3x3 branch uses padding 1).

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by the module.
    """

    def __init__(self, in_channel, out_channel):
        super(fire, self).__init__()
        # Squeeze width: 1/8 of the output width; clamped so tiny widths do not
        # produce a zero-channel conv.
        squeeze = max(1, out_channel // 8)
        # Split the expand width so the concatenation is exactly out_channel,
        # even when out_channel is odd (identical split for even widths).
        expand1 = out_channel // 2
        expand3 = out_channel - expand1
        self.conv1 = nn.Conv2d(in_channel, squeeze, kernel_size=1)
        self.conv2_1 = nn.Conv2d(squeeze, expand1, kernel_size=1)
        self.conv2_2 = nn.Conv2d(squeeze, expand3, kernel_size=3, padding=3 // 2)
        # BN normalizes conv1's output, so it must have `squeeze` features.
        # (Previously out_channel // 4, which mismatched conv1 and crashed forward.)
        self.BN1 = nn.BatchNorm2d(squeeze)
        self.ReLU = nn.ReLU()

    def forward(self, x):
        """Squeeze -> BN -> ReLU, run both expand branches, concatenate channels, ReLU."""
        out = self.ReLU(self.BN1(self.conv1(x)))
        out1 = self.conv2_1(out)
        out2 = self.conv2_2(out)
        # A single ReLU after concatenation equals applying it to each branch.
        return self.ReLU(torch.cat([out1, out2], 1))
class SQUEEZE(nn.Module):
    """SqueezeNet-style classifier built from ``fire`` modules.

    Layout: 7x7 stride-2 conv -> 3x3 stride-2 max-pool -> BatchNorm -> eight fire
    modules (with 3x3 stride-2 max-pools after the 4th and the 7th) -> Dropout(0.5)
    -> 1x1 conv to ``classses`` channels -> ReLU -> global max-pool.

    Args:
        in_channel: number of channels of the input image.
        classses: number of output classes (parameter name kept for compatibility).

    ``forward`` returns a tensor of shape ``(N, classses, 1, 1)``.
    """

    def __init__(self, in_channel, classses):
        super(SQUEEZE, self).__init__()
        channels = [96, 128, 128, 256, 256, 384, 384, 512, 512]
        # NOTE(review): the SqueezeNet paper applies a ReLU right after conv1;
        # none is applied here (conv1 -> pool1 -> BN1) — confirm this is intended.
        self.conv1 = nn.Conv2d(in_channel, channels[0], 7, 2, padding=7 // 2)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.BN1 = nn.BatchNorm2d(channels[0])
        self.block = fire
        self.block1 = nn.ModuleList([])
        # One fire module per consecutive channel pair (8 in total). The module
        # order, and therefore the state_dict keys, match the original layout.
        for i, (cin, cout) in enumerate(zip(channels[:-1], channels[1:])):
            self.block1.append(self.block(in_channel=cin, out_channel=cout))
            # NOTE(review): the paper / torchvision squeezenet1_0 put the first of
            # these pools after the 3rd fire module (i == 2); kept at i == 3 here
            # so existing checkpoints still load.
            if i in (3, 6):
                self.block1.append(nn.MaxPool2d(kernel_size=3, stride=2))
        self.conv10 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(channels[-1], classses, kernel_size=1, stride=1),
            nn.ReLU(),
        )
        # Global max-pool. For 224x224 inputs the final map is 13x13, so this gives
        # the same result as the former MaxPool2d(13), but it also stays a true
        # global pool for other input resolutions (MaxPool2d(13) ignored borders
        # on larger maps and raised on maps smaller than 13x13).
        self.pool2 = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Return per-class scores of shape ``(N, classses, 1, 1)`` for images ``x``."""
        x = self.BN1(self.pool1(self.conv1(x)))
        for block in self.block1:
            x = block(x)
        x = self.conv10(x)
        return self.pool2(x)
if __name__ == '__main__':
    # Smoke test. Use random data: torch.empty returns uninitialized memory,
    # which can contain arbitrary values (including NaN/inf).
    x = torch.randn(1, 3, 224, 224)
    m = SQUEEZE(3, 10)
    # Inference mode: disables Dropout and uses BN running statistics.
    m.eval()
    with torch.no_grad():
        out = m(x)
    print(out.shape)
    print(out)