轻量级网络之mobilenet_v1 pytorch实现
前言:前文讲解了面向移动端和嵌入式设备的轻量级网络 MobileNet,本文使用 PyTorch 搭建 MobileNet_v1 网络。
一、Mobilenet_v1 网络结构
1.Mobilenet_v1 网络结构如图所示
由此我们可以得出mobilenet_v1的网络结构由标准卷积、深度可分离卷积、平均池化、全连接层组成。
2.标准卷积模块
由 Conv + BN + ReLU 组成。
class ConvBNReLU(nn.Sequential):
    """Standard convolution block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        groups: Number of channel groups forwarded to ``nn.Conv2d``.
        dilation: Kernel dilation forwarded to ``nn.Conv2d``.

    Attributes:
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 3,
                 stride: int = 1, groups: int = 1, dilation: int = 1) -> None:
        # The padding is scaled by the dilation, so the dilation itself must
        # also reach the convolution; otherwise the output grows spatially.
        padding = (kernel_size - 1) // 2 * dilation
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,
                      dilation=dilation, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
        )
        self.out_channels = out_planes


# Alias used by the network definition for the default 3x3 kernel.
Conv3x3BNReLU = ConvBNReLU
3.深度可分离卷积模块
由深度卷积(DW Conv + BN+Relu) 和逐点卷积 (PW Conv + BN + Relu )组成。
class DWConvBNReLU(nn.Sequential):
    """Depthwise convolution block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        groups: Number of channel groups forwarded to ``nn.Conv2d``. When
            omitted it defaults to ``in_planes`` (one group per input
            channel).

    Attributes:
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 3,
                 stride: int = 1, groups: "int | None" = None) -> None:
        # A parameter default cannot refer to another parameter, so the
        # "one group per input channel" default is resolved here instead.
        if groups is None:
            groups = in_planes
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
        )
        self.out_channels = out_planes


# Alias used by the network definition for the default 3x3 kernel.
DWConv3x3BNReLU = DWConvBNReLU
class PWConvBNReLU(nn.Sequential):
    """Pointwise convolution block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square convolution kernel
            (1 by default).
        stride: Convolution stride.
        groups: Number of channel groups forwarded to ``nn.Conv2d``.

    Attributes:
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 1,
                 stride: int = 1, groups: int = 1) -> None:
        # Zero for the default 1x1 kernel; passing it through keeps the
        # spatial size unchanged if a larger odd kernel is requested.
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
        )
        self.out_channels = out_planes


# Alias used by the network definition for the default 1x1 kernel.
PWConv1x1BNReLU = PWConvBNReLU
4.平均池化模块
# The functional form is the lowercase ``avg_pool2d`` and takes the tensor as
# its first argument.
x = nn.functional.avg_pool2d(x, kernel_size=7, stride=1)
5.全连接层模块
x = torch.flatten(x, 1)  # flatten the pooled feature map into one vector per sample
model.fc = nn.Linear(1024, 1000)
二、完整的代码实现
import torch
import torch.nn as nn
import torchvision
class ConvBNReLU(nn.Sequential):
    """Standard convolution block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        groups: Number of channel groups forwarded to ``nn.Conv2d``.
        dilation: Kernel dilation forwarded to ``nn.Conv2d``.

    Attributes:
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 3,
                 stride: int = 1, groups: int = 1, dilation: int = 1) -> None:
        # The padding is scaled by the dilation, so the dilation itself must
        # also reach the convolution; otherwise the output grows spatially.
        padding = (kernel_size - 1) // 2 * dilation
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,
                      dilation=dilation, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
        )
        self.out_channels = out_planes


# Alias used by the network definition for the default 3x3 kernel.
Conv3x3BNReLU = ConvBNReLU
class DWConvBNReLU(nn.Sequential):
    """Depthwise convolution block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        groups: Number of channel groups forwarded to ``nn.Conv2d``. When
            omitted it defaults to ``in_planes`` (one group per input
            channel).

    Attributes:
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 3,
                 stride: int = 1, groups: "int | None" = None) -> None:
        # A parameter default cannot refer to another parameter, so the
        # "one group per input channel" default is resolved here instead.
        if groups is None:
            groups = in_planes
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
        )
        self.out_channels = out_planes


# Alias used by the network definition for the default 3x3 kernel.
DWConv3x3BNReLU = DWConvBNReLU
class PWConvBNReLU(nn.Sequential):
    """Pointwise convolution block: Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square convolution kernel
            (1 by default).
        stride: Convolution stride.
        groups: Number of channel groups forwarded to ``nn.Conv2d``.

    Attributes:
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 1,
                 stride: int = 1, groups: int = 1) -> None:
        # Zero for the default 1x1 kernel; passing it through keeps the
        # spatial size unchanged if a larger odd kernel is requested.
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
        )
        self.out_channels = out_planes


# Alias used by the network definition for the default 1x1 kernel.
PWConv1x1BNReLU = PWConvBNReLU
class mobilenet_v1(nn.Module):
    """MobileNet v1 image classifier.

    The network is one standard 3x3 convolution followed by thirteen
    depthwise (3x3) / pointwise (1x1) convolution pairs, global average
    pooling, dropout and a fully connected classifier. The shape comments
    below assume a 224 x 224 x 3 input and ``width_factor == 1.0``.

    Args:
        num_classes: Number of output classes of the final linear layer.
        width_factor: Multiplier applied to every channel count of the
            baseline network (1.0 reproduces the baseline widths).
    """

    def __init__(self, num_classes: int = 1000, width_factor: float = 1.0) -> None:
        super().__init__()

        def ch(base: int) -> int:
            # Scale a baseline channel count, never dropping below one channel.
            return max(1, int(base * width_factor))

        c32, c64, c128, c256, c512, c1024 = (
            ch(c) for c in (32, 64, 128, 256, 512, 1024)
        )

        # Conv / s2: standard 3x3 convolution, 224x224x3 -> 112x112x32.
        self.Conv1 = Conv3x3BNReLU(in_planes=3, out_planes=c32, stride=2, groups=1)
        # Conv dw / s1 followed by Conv 1x1 / s1: 112x112x32 -> 112x112x64.
        self.DWConv2 = DWConv3x3BNReLU(in_planes=c32, out_planes=c32, stride=1, groups=c32)
        self.PWConv3 = PWConv1x1BNReLU(in_planes=c32, out_planes=c64, stride=1, groups=1)
        # Conv dw / s2 halves the feature map: 112x112x64 -> 56x56x64.
        self.DWConv4 = DWConv3x3BNReLU(in_planes=c64, out_planes=c64, stride=2, groups=c64)
        # Conv 1x1 / s1 (pointwise, not a 3x3 convolution): 56x56x64 -> 56x56x128.
        self.PWConv5 = PWConv1x1BNReLU(in_planes=c64, out_planes=c128, stride=1, groups=1)
        # Conv dw / s1 followed by Conv 1x1 / s1: 56x56x128 -> 56x56x128.
        self.DWConv6 = DWConv3x3BNReLU(in_planes=c128, out_planes=c128, stride=1, groups=c128)
        self.PWConv7 = PWConv1x1BNReLU(in_planes=c128, out_planes=c128, stride=1, groups=1)
        # Conv dw / s2 halves the feature map: 56x56x128 -> 28x28x128.
        self.DWConv8 = DWConv3x3BNReLU(in_planes=c128, out_planes=c128, stride=2, groups=c128)
        # Conv 1x1 / s1: 28x28x128 -> 28x28x256.
        self.PWConv9 = PWConv1x1BNReLU(in_planes=c128, out_planes=c256, stride=1, groups=1)
        # Conv dw / s1 followed by Conv 1x1 / s1: 28x28x256 -> 28x28x256.
        self.DWConv10 = DWConv3x3BNReLU(in_planes=c256, out_planes=c256, stride=1, groups=c256)
        self.PWConv11 = PWConv1x1BNReLU(in_planes=c256, out_planes=c256, stride=1, groups=1)
        # Conv dw / s2 halves the feature map: 28x28x256 -> 14x14x256.
        self.DWConv12 = DWConv3x3BNReLU(in_planes=c256, out_planes=c256, stride=2, groups=c256)
        # Conv 1x1 / s1: 14x14x256 -> 14x14x512.
        self.PWConv13 = PWConv1x1BNReLU(in_planes=c256, out_planes=c512, stride=1, groups=1)
        # Five repetitions of (Conv dw / s1, Conv 1x1 / s1) at 14x14x512.
        self.DWConv14 = DWConv3x3BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=c512)
        self.PWConv15 = PWConv1x1BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=1)
        self.DWConv16 = DWConv3x3BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=c512)
        self.PWConv17 = PWConv1x1BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=1)
        self.DWConv18 = DWConv3x3BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=c512)
        self.PWConv19 = PWConv1x1BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=1)
        self.DWConv20 = DWConv3x3BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=c512)
        self.PWConv21 = PWConv1x1BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=1)
        self.DWConv22 = DWConv3x3BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=c512)
        self.PWConv23 = PWConv1x1BNReLU(in_planes=c512, out_planes=c512, stride=1, groups=1)
        # Conv dw / s2 halves the feature map: 14x14x512 -> 7x7x512.
        self.DWConv24 = DWConv3x3BNReLU(in_planes=c512, out_planes=c512, stride=2, groups=c512)
        # Conv 1x1 / s1: 7x7x512 -> 7x7x1024.
        self.PWConv25 = PWConv1x1BNReLU(in_planes=c512, out_planes=c1024, stride=1, groups=1)
        # Last depthwise layer keeps stride 1 so the map stays 7x7x1024.
        self.DWConv26 = DWConv3x3BNReLU(in_planes=c1024, out_planes=c1024, stride=1, groups=c1024)
        # Conv 1x1 / s1: 7x7x1024 -> 7x7x1024.
        self.PWConv27 = PWConv1x1BNReLU(in_planes=c1024, out_planes=c1024, stride=1, groups=1)
        # Global average pooling: equals the 7x7 average pool for a 224x224
        # input and still yields a 1x1 map for other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.2)
        # FC: one logit per class; softmax is left to the loss / caller.
        self.fc = nn.Linear(in_features=c1024, out_features=num_classes)
        self.init_param()

    def init_param(self) -> None:
        """Initialise convolution, batch-norm and linear parameters."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # Identity affine transform at the start of training.
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # Small random weights; a constant fill would give every
                # class the same logit for any input.
                nn.init.normal_(m.weight, 0.0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Input batch with 3 channels in dimension 1 (``Conv1`` is
                built with ``in_planes=3``).

        Returns:
            Tensor with one row of ``num_classes`` logits per sample.
        """
        x = self.Conv1(x)
        x = self.DWConv2(x)
        x = self.PWConv3(x)
        x = self.DWConv4(x)
        x = self.PWConv5(x)
        x = self.DWConv6(x)
        x = self.PWConv7(x)
        x = self.DWConv8(x)
        x = self.PWConv9(x)
        x = self.DWConv10(x)
        x = self.PWConv11(x)
        x = self.DWConv12(x)
        x = self.PWConv13(x)
        x = self.DWConv14(x)
        x = self.PWConv15(x)
        x = self.DWConv16(x)
        x = self.PWConv17(x)
        x = self.DWConv18(x)
        x = self.PWConv19(x)
        x = self.DWConv20(x)
        x = self.PWConv21(x)
        x = self.DWConv22(x)
        x = self.PWConv23(x)
        x = self.DWConv24(x)
        x = self.PWConv25(x)
        x = self.DWConv26(x)
        x = self.PWConv27(x)
        x = self.avgpool(x)
        # Flatten everything but the batch dimension, so the channel count
        # does not have to be hard-coded here.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x
三、 测试
1.输入命令
python mobilenet_v1.py
2.测试结果
torch.Size([1, 32, 112, 112])
torch.Size([1, 32, 112, 112])
torch.Size([1, 64, 112, 112])
torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 56, 56])
torch.Size([1, 128, 56, 56])
torch.Size([1, 128, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 256, 28, 28])
torch.Size([1, 256, 28, 28])
torch.Size([1, 256, 28, 28])
torch.Size([1, 256, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 7, 7])
torch.Size([1, 1024, 7, 7])
torch.Size([1, 1024, 7, 7])
torch.Size([1, 1024, 7, 7])
torch.Size([1, 1024, 1, 1])
torch.Size([1, 1024])
torch.Size([1, 1024])
torch.Size([1, 3])
torch.Size([1, 3])