(Expert post) (MobileNet series overview) Wise Object Detection 47: Building a YOLOv4 Object Detection Platform in Keras with the MobileNet Series (v1, v2, v3)

Original article: https://blog.csdn.net/weixin_44791964/article/details/107359153?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522163946123316780264096618%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257D&request_id=163946123316780264096618&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2blogfirst_rank_v2~rank_v29-5-107359153.pc_v2_rank_blog_default&utm_term=mobilenet&spm=1018.2226.3001.4450#amobilenetV1_26

Preface

Let's take a look at how to build a YOLOv4 object detection platform with the MobileNet series.

Source code download

https://github.com/bubbliiiing/mobilenet-yolov4-keras
If you find it useful, feel free to give it a star.

How the network replacement works

1. Network structure and replacement approach

[Figure: YOLOv4 network structure]
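The YOLOv4 network as a whole can be divided into three parts:
1. The backbone feature extraction network (Backbone), which corresponds to CSPDarknet53 in the figure
2. The enhanced feature extraction network, which corresponds to SPP and PANet in the figure
3. The prediction network (YoloHead), which uses the extracted features to make predictions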

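Specifically:
The backbone performs the initial feature extraction and yields three preliminary effective feature layers.
The enhanced feature extraction network fuses these three preliminary effective feature layers to extract better features, producing three more effective feature layers.
The prediction network uses these more effective feature layers to produce the prediction results.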

Of the three parts, parts 1 and 2 are the easiest to modify. Part 3 leaves little room for change, since it is just a combination of 3x3 and 1x1 convolutions.

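The MobileNet networks are designed for classification, and their backbone serves to extract features. We can therefore use a MobileNet network in place of CSPDarknet53 as YOLOv4's feature extractor: take the MobileNet feature layers whose height and width match those of the three preliminary effective feature layers and feed them into the enhanced feature extraction network. That is all it takes to swap the MobileNet series into YOLOv4.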
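A minimal sketch of the shape check behind this idea, assuming a square input whose side is a multiple of 32 and the usual YOLOv4 strides of 8, 16 and 32. The helper names `expected_spatial_sizes` and `can_replace_backbone` are hypothetical. Only height and width are compared; the assumption is that the neck's convolutions adapt the channel count, so it need not equal CSPDarknet53's.

```python
# CSPDarknet53's three effective feature layers for a 416x416 input, as (H, W, C)
CSPDARKNET53_FEATS = [(52, 52, 256), (26, 26, 512), (13, 13, 1024)]


def expected_spatial_sizes(input_size=416):
    """(H, W) of the three effective feature layers at strides 8, 16 and 32."""
    if input_size % 32 != 0:
        raise ValueError("input_size should be a multiple of 32")
    return [(input_size // s, input_size // s) for s in (8, 16, 32)]


def can_replace_backbone(feature_shapes, input_size=416):
    """True if a backbone's (H, W, C) feature layers match YOLOv4's spatial sizes.

    Channels are deliberately not compared (assumption: the neck adapts them).
    """
    spatial = [(h, w) for h, w, _c in feature_shapes]
    return spatial == expected_spatial_sizes(input_size)


print(expected_spatial_sizes(416))               # [(52, 52), (26, 26), (13, 13)]
print(can_replace_backbone(CSPDARKNET53_FEATS))  # True
```

Any backbone that exposes feature maps at these three resolutions passes the check, whatever its channel counts.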

2. The MobileNet series

This article uses three backbone feature extraction networks: MobileNetV1, MobileNetV2 and MobileNetV3.

a. MobileNetV1

MobileNet is a lightweight deep neural network proposed by Google for mobile phones and other embedded devices. Its core idea is the depthwise separable convolution.

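First consider an ordinary convolution.
Suppose a 3×3 convolutional layer has 16 input channels and 32 output channels. Each of its 32 kernels is 3×3×16, i.e. it covers all 16 input channels, and together they produce the 32 output channels. This requires 16×32×3×3 = 4608 parameters.

With a depthwise separable convolution block, 16 kernels of size 3×3 each process one of the 16 input channels separately, producing 16 feature maps. These are then fused by 32 kernels of size 1×1 (each spanning all 16 feature maps), for a total of 16×3×3 + 16×32×1×1 = 656 parameters.
So in this example the depthwise separable convolution needs about 7× fewer parameters (656 instead of 4608).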
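The arithmetic above can be reproduced in a few lines of Python. This sketch counts weights only (biases omitted, as in the text); `depth_multiplier` mirrors the Keras `DepthwiseConv2D` argument of the same name and defaults to 1, and the function names are illustrative.

```python
def standard_conv_params(c_in, c_out, k=3):
    """Weights of an ordinary k x k convolution (bias omitted)."""
    return c_in * c_out * k * k


def depthwise_separable_params(c_in, c_out, k=3, depth_multiplier=1):
    """Weights of a k x k depthwise conv followed by a 1x1 pointwise conv (bias omitted)."""
    depthwise = c_in * depth_multiplier * k * k  # one k x k kernel per input channel
    pointwise = c_in * depth_multiplier * c_out  # c_out kernels of size 1 x 1 x (c_in * depth_multiplier)
    return depthwise + pointwise


standard = standard_conv_params(16, 32)
separable = depthwise_separable_params(16, 32)
print(standard, separable)             # 4608 656
print(round(standard / separable, 2))  # 7.02
```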

The figure below shows the structure of a depthwise separable convolution:
[Figure: structure of a depthwise separable convolution]
When building the model, the depthwise part can be implemented with Keras's DepthwiseConv2D layer, followed by a 1x1 convolution that adjusts the number of channels.

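Intuitively, each 3x3 kernel is only one channel thick: it slides over the input tensor one channel at a time, and each such convolution produces one output channel. After that, a 1x1 convolution adjusts the depth (the number of channels).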
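To make the "one channel at a time" picture concrete, here is a NumPy sketch of a depthwise separable convolution with stride 1, zero "same" padding and no bias, batch normalization or activation. It illustrates the technique only and is not how Keras implements `DepthwiseConv2D` internally; the function names are hypothetical.

```python
import numpy as np


def depthwise_conv2d(x, dw_kernels):
    """Depthwise 'same' convolution with stride 1.

    x:          (H, W, C) input feature map
    dw_kernels: (k, k, C), one k x k kernel per input channel
    """
    h, w, _ = x.shape
    k = dw_kernels.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))  # zero padding for 'same'
    out = np.zeros(x.shape, dtype=np.result_type(x, dw_kernels))
    for i in range(h):
        for j in range(w):
            # Each channel is multiplied only by its own kernel; channels are not summed together
            out[i, j, :] = np.sum(xp[i:i + k, j:j + k, :] * dw_kernels, axis=(0, 1))
    return out


def pointwise_conv2d(x, pw_kernels):
    """1x1 convolution: a matrix product over the channel axis. pw_kernels: (C_in, C_out)."""
    return x @ pw_kernels


def depthwise_separable_conv2d(x, dw_kernels, pw_kernels):
    return pointwise_conv2d(depthwise_conv2d(x, dw_kernels), pw_kernels)


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 16))
y = depthwise_separable_conv2d(x, rng.standard_normal((3, 3, 16)), rng.standard_normal((16, 32)))
print(y.shape)  # (8, 8, 32)
```

Both steps are linear, so their composition equals an ordinary 3x3 convolution whose kernel factorizes as `dw[a, b, c] * pw[c, o]`. The depthwise separable form restricts the standard convolution to kernels of that shape, which is where the parameter savings come from.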

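The structure of MobileNet is shown below. Conv dw denotes the depthwise (per-channel) convolution, and each one is followed by a 1x1 convolution that processes the channels:
[Figure: MobileNetV1 network structure]
The figure shows MobileNetV1 with alpha = 1. Setting alpha (the width multiplier) to a smaller value scales down the number of channels in every layer.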

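For YOLOv4, we need to take the effective feature layers at the network's last three resolutions and pass them to the enhanced feature extraction network.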
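A small sketch of the shapes these three layers take in the MobileNetV1 code, assuming a square input whose side is divisible by 32. The channel counts follow the same `int(filters * alpha)` rule as `_depthwise_conv_block`; the function name `mobilenetv1_feature_shapes` is hypothetical.

```python
def mobilenetv1_feature_shapes(input_size=416, alpha=1.0):
    """(H, W, C) of feat1, feat2, feat3 from MobileNetV1 for a square input."""
    if alpha not in (0.25, 0.5, 0.75, 1.0):
        raise ValueError("alpha must be one of 0.25, 0.5, 0.75, 1.0")
    if input_size % 32 != 0:
        raise ValueError("input_size should be a multiple of 32")
    shapes = []
    # (stride, channels at alpha = 1) of the layers taken as feat1, feat2, feat3
    for stride, filters in ((8, 256), (16, 512), (32, 1024)):
        size = input_size // stride
        shapes.append((size, size, int(filters * alpha)))  # same rounding as _depthwise_conv_block
    return shapes


print(mobilenetv1_feature_shapes(416))              # [(52, 52, 256), (26, 26, 512), (13, 13, 1024)]
print(mobilenetv1_feature_shapes(416, alpha=0.25))  # [(52, 52, 64), (26, 26, 128), (13, 13, 256)]
```

With alpha = 1 the shapes coincide exactly with CSPDarknet53's three effective feature layers, which is what makes MobileNetV1 a drop-in replacement.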

In the code, these three layers are taken out as feat1, feat2 and feat3.

#-------------------------------------------------------------#
#   MobileNet (V1) network
#-------------------------------------------------------------#
from keras import backend as K
from keras.layers import Activation, BatchNormalization, Conv2D, DepthwiseConv2D
from keras.models import Model

def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha,
                          depth_multiplier=1, strides=(1, 1), block_id=1):
    # The width multiplier alpha scales the number of output channels
    pointwise_conv_filters = int(pointwise_conv_filters * alpha)

    # 3x3 depthwise convolution (first half of the depthwise separable convolution)
    x = DepthwiseConv2D((3, 3),
                        padding='same',
                        depth_multiplier=depth_multiplier,
                        strides=strides,
                        use_bias=False,
                        name='conv_dw_%d' % block_id)(inputs)
    x = BatchNormalization(name='conv_dw_%d_bn' % block_id)(x)
    x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)

    # 1x1 pointwise convolution: mixes the channels and sets the output width
    x = Conv2D(pointwise_conv_filters, (1, 1),
               padding='same',
               use_bias=False,
               strides=(1, 1),
               name='conv_pw_%d' % block_id)(x)
    x = BatchNormalization(name='conv_pw_%d_bn' % block_id)(x)
    return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)

def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
    filters = int(filters * alpha)
    x = Conv2D(filters, kernel,
               padding='same',
               use_bias=False,
               strides=strides,
               name='conv1')(inputs)
    x = BatchNormalization(name='conv1_bn')(x)
    return Activation(relu6, name='conv1_relu')(x)

def relu6(x):
    return K.relu(x, max_value=6)

def MobileNetV1(inputs, alpha=1, depth_multiplier=1):
    if alpha not in [0.25, 0.5, 0.75, 1.0]:
        raise ValueError('Unsupported alpha - {} in MobilenetV1, Use 0.25, 0.5, 0.75, 1.0'.format(alpha))

    # Shape comments assume a 416x416 input and alpha = 1
    # 416,416,3 -> 208,208,32
    x = _conv_block(inputs, 32, alpha, strides=(2, 2))
    # 208,208,32 -> 208,208,64
    x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)

    # 208,208,64 -> 104,104,128
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier,
                              strides=(2, 2), block_id=2)
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)

    # 104,104,128 -> 52,52,256
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier,
                              strides=(2, 2), block_id=4)
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)
    # First effective feature layer (stride 8)
    feat1 = x

    # 52,52,256 -> 26,26,512
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier,
                              strides=(2, 2), block_id=6)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)
    # Second effective feature layer (stride 16)
    feat2 = x

    # 26,26,512 -> 13,13,1024
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier,
                              strides=(2, 2), block_id=12)
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)
    # Third effective feature layer (stride 32)
    feat3 = x

    return feat1, feat2, feat3

if __name__ == '__main__':
    from keras.layers import Input
    from keras.models import Model
    alpha = 0.25
    inputs = Input([None, None, 3])
    outputs = MobileNetV1(inputs, alpha=alpha)
    model = Model(inputs, outputs)
    model.summary()


b. MobileNetV2

MobileNetV2 is an upgraded version of MobileNet. Its most important feature is the Inverted resblock (inverted residual block): the body of MobileNetV2 is built from stacked Inverted resblocks.

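An Inverted resblock consists of two parts:
The left side is the main branch: a 1x1 convolution first expands the channels, a 3x3 depthwise convolution then extracts features, and a final 1x1 convolution reduces the channels again.
The right side is the residual branch, which adds the input directly to the output (only when the stride is 1 and the input and output channel counts match).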
[Figure: structure of the Inverted resblock]
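The channel flow of one Inverted resblock can be summarized in a short sketch. `inverted_residual_plan` is a hypothetical helper; it assumes the expansion layer is skipped when `expansion == 1` (the `_inverted_res_block` code instead skips it when `block_id == 0`, the first block, which is also the one created with `expansion=1`) and it ignores the `alpha` rounding done by `_make_divisible`.

```python
def inverted_residual_plan(in_channels, out_channels, expansion, stride):
    """Return ([(layer description, output channels), ...], uses_residual)."""
    hidden = in_channels * expansion
    plan = []
    if expansion != 1:
        # 1x1 conv raises the channel count (the "inverted" part: wide in the middle)
        plan.append(("1x1 conv expand + BN + ReLU6", hidden))
    # 3x3 depthwise conv keeps the channel count; it may downsample
    plan.append((f"3x3 depthwise conv, stride {stride} + BN + ReLU6", hidden))
    # 1x1 conv reduces the channels again, with no activation (linear bottleneck)
    plan.append(("1x1 conv project + BN", out_channels))
    # The shortcut is only possible when input and output shapes match
    uses_residual = stride == 1 and in_channels == out_channels
    return plan, uses_residual


plan, uses_residual = inverted_residual_plan(24, 24, expansion=6, stride=1)
for layer, channels in plan:
    print(f"{layer:40s} -> {channels} channels")
print("residual:", uses_residual)  # residual: True
```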

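The overall network structure is shown below (each Inverted resblock performs the operations described above):
[Figure: overall MobileNetV2 network structure]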
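As a rough guide to which layers could serve as the three effective feature layers, the sketch below walks the (t, c, n, s) configuration table from the MobileNetV2 paper for a 416x416 input with alpha = 1. Taking the last stage at each of the 52x52, 26x26 and 13x13 resolutions is an assumption made here for illustration, and is not necessarily the exact layer choice in the repository code; the names are hypothetical.

```python
# (t, c, n, s) rows from the MobileNetV2 paper: expansion factor, output channels,
# number of repeated blocks, stride of the first block in the stage
MOBILENETV2_CFG = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
]


def stage_outputs(input_size=416):
    """(spatial size, channels) after each bottleneck stage, alpha = 1."""
    size = input_size // 2  # stem: 3x3 conv with stride 2
    outputs = []
    for _t, c, _n, s in MOBILENETV2_CFG:
        size //= s  # only the first block of a stage may downsample
        outputs.append((size, c))
    return outputs


def last_per_resolution(outputs, sizes):
    """Pick the deepest stage output at each requested spatial size."""
    picked = {}
    for size, c in outputs:
        if size in sizes:
            picked[size] = (size, size, c)  # later stages overwrite earlier ones
    return [picked[s] for s in sizes]


print(last_per_resolution(stage_outputs(416), [52, 26, 13]))
# [(52, 52, 32), (26, 26, 96), (13, 13, 320)]
```

The shape comments in the `MobileNetV2` code (for example `104,104,24 -> 52,52,32`) follow the same progression.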

#-------------------------------------------------------------#
#   MobileNetV2 network
#-------------------------------------------------------------#
import math
import numpy as np
import tensorflow as tf
from keras import backend
from keras.preprocessing import image
from keras.models import Model
from keras.layers import BatchNormalization
from keras.layers import Conv2D, Add, ZeroPadding2D, GlobalAveragePooling2D, Dropout, Dense
from keras.layers import MaxPooling2D,Activation,DepthwiseConv2D,Input,GlobalMaxPooling2D
from keras.applications import imagenet_utils
from keras.applications.imagenet_utils import decode_predictions
from keras.utils.data_utils import get_file

# TODO Change path to v1.1
BASE_WEIGHT_PATH = ('https://github.com/JonathanCMitchell/mobilenet_v2_keras/'
                    'releases/download/v1.1/')

# relu6!
def relu6(x):
    return backend.relu(x, max_value=6)

# Compute the zero padding for a downsampling (stride-2) convolution; assumes channels_last
def correct_pad(inputs, kernel_size):
    img_dim = 1
    input_size = backend.int_shape(inputs)[img_dim:(img_dim + 2)]

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    if input_size[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)

    correct = (kernel_size[0] // 2, kernel_size[1] // 2)

    return ((correct[0] - adjust[0], correct[0]),
            (correct[1] - adjust[1], correct[1]))

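# Round the channel count to a multiple of divisor (8 here), because the width multiplier alpha can produce arbitrary channel counts
def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure rounding down does not remove more than 10% of the channels
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v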

def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id):
    in_channels = backend.int_shape(inputs)[-1]
    pointwise_conv_filters = int(filters * alpha)
    pointwise_filters = _make_divisible(pointwise_conv_filters, 8)

    x = inputs
    prefix = 'block_{}_'.format(block_id)
    # part 1: expand the channels with a 1x1 convolution (skipped for block 0)
    if block_id:
        # Expand
        x = Conv2D(expansion * in_channels,
                   kernel_size=1,
                   padding='same',
                   use_bias=False,
                   activation=None,
                   name=prefix + 'expand')(x)
        x = BatchNormalization(epsilon=1e-3,
                               momentum=0.999,
                               name=prefix + 'expand_BN')(x)
        x = Activation(relu6, name=prefix + 'expand_relu')(x)
    else:
        prefix = 'expanded_conv_'

    # Stride-2 blocks pad explicitly and then use a 'valid' depthwise convolution
    if stride == 2:
        x = ZeroPadding2D(padding=correct_pad(x, 3),
                          name=prefix + 'pad')(x)

    # part 2: 3x3 depthwise convolution
    x = DepthwiseConv2D(kernel_size=3,
                        strides=stride,
                        activation=None,
                        use_bias=False,
                        padding='same' if stride == 1 else 'valid',
                        name=prefix + 'depthwise')(x)
    x = BatchNormalization(epsilon=1e-3,
                           momentum=0.999,
                           name=prefix + 'depthwise_BN')(x)

    x = Activation(relu6, name=prefix + 'depthwise_relu')(x)

    # part 3: project (compress) the features with a 1x1 convolution and no ReLU,
    # so that the low-dimensional features are not destroyed
    x = Conv2D(pointwise_filters,
               kernel_size=1,
               padding='same',
               use_bias=False,
               activation=None,
               name=prefix + 'project')(x)

    x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'project_BN')(x)

    # Residual connection only when the input and output shapes match
    if in_channels == pointwise_filters and stride == 1:
        return Add(name=prefix + 'add')([inputs, x])
    return x

def MobileNetV2(inputs, alpha=1.0):
    if alpha not in [0.5, 0.75, 1.0, 1.3]:
        raise ValueError('Unsupported alpha - {} in MobilenetV2, Use 0.5, 0.75, 1.0, 1.3'.format(alpha))
    # Stem
    first_block_filters = _make_divisible(32 * alpha, 8)
    x = ZeroPadding2D(padding=correct_pad(inputs, 3),
                      name='Conv1_pad')(inputs)
    # 416,416,3 -> 208,208,32
    x = Conv2D(first_block_filters,
               kernel_size=3,
               strides=(2, 2),
               padding='valid',
               use_bias=False,
               name='Conv1')(x)
    x = BatchNormalization(epsilon=1e-3,
                           momentum=0.999,
                           name='bn_Conv1')(x)
    x = Activation(relu6, name='Conv1_relu')(x)

    # 208,208,32 -> 208,208,16
    x = _inverted_res_block(x, filters=16, alpha=alpha, stride=1,
                            expansion=1, block_id=0)
    # 208,208,16 -> 104,104,24
    x = _inverted_res_block(x, filters=24, alpha=alpha, stride=2,
                            expansion=6, block_id=1)
    x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1,
                            expansion=6, block_id=2)

    # 104,104,24 -> 52,52,32
    x = _inverted_res_block(x, filters=32, alpha=alpha, stride=2,
                            expansion=6, block_id=3)
    x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1,
                            expansion=6, block_id=4)
    x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1,
                            expansion=6, block_id=5)
    feat1 = x

    # 52,52,32 -> 26,26,96
    x = _inverted_res_block(x, filters=64, alpha=alpha, stride=2,
                            expansion=6, block_id=6)
    x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1,
                            expansion=6, block_id=7)
    x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1,
                            expansion=6, block_id=8)
    x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1,
                            expansion=6, block_id=9)
    x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1,
                            expansion=6, block_id=10)
    x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1,
                            expansion=6, block_id=11)
    x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1,
                            expansion=6, block_id=12)
    feat2 = x

    # 26,26,96 -> 13,13,320
    x = _inverted_res_block(x, filters=160, alpha=alpha, stride=2,
                            expansion=6, block_id=13)
    x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1,
                            expansion=6, block_id=14)
    x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1,
                            expansion=6, block_id=15)
    x = _inverted_res_block(x, filters=320, alpha=alpha, stride=1,
                            expansion=6, block_id=16)
    feat3 = x

    return feat1, feat2, feat3

c. Introduction to mobilenetV3

mobilenetV3 is built from a special block called bneck.

The bneck structure is shown in the figure below:
[Figure: the bneck structure]
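It combines the following four features:
a. The inverted residual with linear bottleneck from MobileNetV2.
[Figure: inverted residual with linear bottleneck]
A 1x1 convolution first expands the number of channels, the operations described below are then applied, a final 1x1 convolution without activation (the "linear bottleneck") projects back to fewer channels, and a residual (shortcut) connection adds the input back when the input and output shapes match.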

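b. The depthwise separable convolutions from MobileNetV1.
[Figure: depthwise separable convolution]
After the 1x1 convolution expands the input, a depthwise convolution (3x3, or 5x5 in some blocks) is applied to each channel separately.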
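To see why expand, depthwise and project stays cheap, the sketch below counts only the convolution weights of one bneck and compares them with a plain convolution that has the same input and output channels. BatchNorm and SE parameters are ignored, and all convolutions are assumed bias-free, as in the code later in this section. `bneck_conv_weights` and `standard_conv_weights` are hypothetical helpers; the example rows (16 -> 24 with exp size 64 and 3x3, 24 -> 40 with exp size 72 and 5x5) are taken from the MobileNetV3-Large table.

```python
def bneck_conv_weights(in_ch, exp_ch, out_ch, kernel_size, expand=True):
    """Conv weights of expand (1x1) + depthwise (k x k) + project (1x1), no biases."""
    # 1x1 expansion conv; the first bneck (block_id 0) has no expansion conv
    expand_w = in_ch * exp_ch if expand else 0
    # Depthwise conv: a single k x k filter per expanded channel
    depthwise_w = kernel_size * kernel_size * exp_ch
    # Linear 1x1 projection down to out_ch
    project_w = exp_ch * out_ch
    return expand_w + depthwise_w + project_w


def standard_conv_weights(in_ch, out_ch, kernel_size):
    # A plain k x k convolution mapping in_ch -> out_ch
    return kernel_size * kernel_size * in_ch * out_ch


print(bneck_conv_weights(16, 64, 24, 3), standard_conv_weights(16, 24, 3))
print(bneck_conv_weights(24, 72, 40, 5), standard_conv_weights(24, 40, 5))
```

With a 3x3 kernel the bneck costs roughly as much as a plain convolution (3136 vs 3456 weights) while filtering in a 4x wider space; with a 5x5 kernel the saving is large (6408 vs 24000), because the k x k cost is paid only once per channel in the depthwise stage.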

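c. A lightweight attention module (squeeze-and-excitation).
[Figure: lightweight attention module]
This attention mechanism works by re-weighting each channel: the feature map is globally average-pooled, passed through two 1x1 convolutions, and the resulting per-channel weights, squashed into [0, 1] by a hard sigmoid, are multiplied back onto the feature map.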
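The channel re-weighting can be sketched without Keras. The plain-Python function below follows the same steps as the `if se_ratio:` branch of `_bneck` later in this section (average pool, 1x1 conv + ReLU, 1x1 conv, hard sigmoid, multiply) on a tiny feature map. The weights are hand-picked for illustration rather than trained, and `squeeze_excite` is a hypothetical helper.

```python
def hard_sigmoid(x):
    # relu6(x + 3) / 6, the same formula as hard_sigmoid in the Keras code
    return min(max(x + 3.0, 0.0), 6.0) / 6.0


def squeeze_excite(feature_map, w1, b1, w2, b2):
    """feature_map: C channels, each a flat list of H*W values.
    w1: reduced_ch x C weights, w2: C x reduced_ch weights (1x1 convs on a 1x1 map)."""
    # Squeeze: global average pooling gives one number per channel
    squeezed = [sum(ch) / len(ch) for ch in feature_map]
    # Excite: 1x1 conv down to reduced_ch channels, then ReLU
    hidden = [max(0.0, sum(w * s for w, s in zip(row, squeezed)) + b)
              for row, b in zip(w1, b1)]
    # 1x1 conv back up to C channels, then hard sigmoid: one gate in [0, 1] per channel
    gates = [hard_sigmoid(sum(w * h for w, h in zip(row, hidden)) + b)
             for row, b in zip(w2, b2)]
    # Scale: multiply every value of channel c by gates[c]
    scaled = [[g * v for v in ch] for g, ch in zip(gates, feature_map)]
    return scaled, gates


fmap = [[1.0, 3.0], [4.0, 4.0], [2.0, 2.0]]        # 3 channels, 2 positions each
w1, b1 = [[1.0, 0.0, 0.0]], [0.0]                  # reduce 3 -> 1 channel
w2, b2 = [[3.0], [-3.0], [0.0]], [0.0, 0.0, 0.0]   # expand 1 -> 3 channels
scaled, gates = squeeze_excite(fmap, w1, b1, w2, b2)
print(gates)  # [1.0, 0.0, 0.5]
print(scaled)
```

The first channel is kept, the second is switched off and the third is halved; a trained SE module learns these gates from the pooled statistics instead of having them fixed by hand.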

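d. h-swish replaces the swish activation.
The block uses the h-swish activation, h-swish(x) = x · ReLU6(x + 3) / 6, in place of swish, which avoids computing an exponential, reduces computation and improves performance.
[Figure: swish and h-swish]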
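As a rough check of how close the approximation is, the plain-Python sketch below evaluates swish(x) = x · sigmoid(x) and h-swish(x) = x · ReLU6(x + 3) / 6 on a grid over [-6, 6]. `swish`, `relu6` and `h_swish` are illustrative helpers here, not Keras functions.

```python
import math


def swish(x):
    # swish(x) = x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def relu6(x):
    return min(max(x, 0.0), 6.0)


def h_swish(x):
    # h-swish(x) = x * relu6(x + 3) / 6, the formula behind hard_swish in the Keras code
    return x * relu6(x + 3.0) / 6.0


grid = [i / 10 for i in range(-60, 61)]  # -6.0, -5.9, ..., 6.0
max_gap = max(abs(swish(x) - h_swish(x)) for x in grid)
print(round(max_gap, 3))  # 0.142
```

The largest gap on this grid is about 0.14, reached at x = ±3; outside [-3, 3] h-swish is exactly 0 or exactly x, so only piecewise-linear operations are needed.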

The figure below shows the overall structure of mobilenetV3:
[Figure: mobilenetV3 architecture table]
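How should this table be read? Let's go through it column by column:
Column 1, Input, gives the shape of the feature map entering each layer of mobilenetV3;
Column 2, Operator, gives the block that the feature map is about to pass through; most of the feature extraction in MobileNetV3 is done by a stack of bneck blocks;
Columns 3 and 4 give, respectively, the number of channels after the inverted residual expands the features inside the bneck (exp size) and the number of channels the block outputs (#out);
Column 5, SE, indicates whether the attention mechanism is used in that layer;
Column 6, NL, gives the type of activation function: HS means h-swish and RE means ReLU;
Column 7, s, gives the stride used by each block.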
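To connect the table to the `_bneck` calls below: the table lists the expanded width (exp size), while `_bneck` takes an expansion ratio relative to the block's input channels. The sketch walks the rows of the MobileNetV3-Large table (assumed to be the one in the figure; values as given in the MobileNetV3 paper) for a 416x416 input and recovers the ratios used in the code together with the shapes of the three feature layers. `ROWS` and `trace_table` are hypothetical names.

```python
import math

# MobileNetV3-Large bneck rows (assumed to match the figure):
# (kernel, exp_size, out_channels, use_se, nonlinearity, stride)
ROWS = [
    (3, 16, 16, False, "RE", 1),
    (3, 64, 24, False, "RE", 2),
    (3, 72, 24, False, "RE", 1),
    (5, 72, 40, True, "RE", 2),
    (5, 120, 40, True, "RE", 1),
    (5, 120, 40, True, "RE", 1),
    (3, 240, 80, False, "HS", 2),
    (3, 200, 80, False, "HS", 1),
    (3, 184, 80, False, "HS", 1),
    (3, 184, 80, False, "HS", 1),
    (3, 480, 112, True, "HS", 1),
    (3, 672, 112, True, "HS", 1),
    (5, 672, 160, True, "HS", 2),
    (5, 960, 160, True, "HS", 1),
    (5, 960, 160, True, "HS", 1),
]


def trace_table(input_size=416, stem_channels=16):
    """Return (H, C_in, expansion ratio, C_out) for each bneck row."""
    size = math.ceil(input_size / 2)  # stem: 3x3 conv, stride 2, 'same' padding
    in_ch = stem_channels
    steps = []
    for kernel, exp_size, out_ch, use_se, nl, stride in ROWS:
        size = math.ceil(size / stride)  # 'same' padding
        # _bneck expects a ratio: exp_size = ratio * C_in
        steps.append((size, in_ch, round(exp_size / in_ch, 2), out_ch))
        in_ch = out_ch
    return steps


for step in trace_table():
    print(step)
```

The first twelve ratios come out as 1, 4, 3, 3, 3, 3, 6, 2.5, 2.3, 2.3, 6, 6, the values passed as the second argument of `_bneck`, and the outputs after rows 6, 12 and 15 are 52x52x40, 26x26x112 and 13x13x160.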

from keras.layers import Conv2D, DepthwiseConv2D, Dense, GlobalAveragePooling2D, Input
from keras.layers import Activation, BatchNormalization, Add, Multiply, Reshape
from keras.models import Model
from keras import backend

def _activation(x, name='relu'):
    # 'relu' corresponds to RE in the table, 'hardswish' to HS
    if name == 'relu':
        return Activation('relu')(x)
    elif name == 'hardswish':
        return hard_swish(x)
    raise ValueError('Unsupported activation: {}'.format(name))

def hard_sigmoid(x):
    # relu6(x + 3) / 6
    return backend.relu(x + 3.0, max_value=6.0) / 6.0

def hard_swish(x):
    # x * hard_sigmoid(x)
    return Multiply()([Activation(hard_sigmoid)(x), x])

def _make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# Argument order matches the calls in MobileNetV3 below: expansion, out_ch, alpha
def _bneck(inputs, expansion, out_ch, alpha, kernel_size, stride, se_ratio, activation,
           block_id):
    channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1

    in_channels = backend.int_shape(inputs)[channel_axis]
    out_channels = _make_divisible(out_ch * alpha, 8)
    exp_size = _make_divisible(in_channels * expansion, 8)
    x = inputs
    prefix = 'expanded_conv/'
    if block_id:
        # Expand: 1x1 conv raises the channel count to exp_size
        prefix = 'expanded_conv_{}/'.format(block_id)
        x = Conv2D(exp_size,
                   kernel_size=1,
                   padding='same',
                   use_bias=False,
                   name=prefix + 'expand')(x)
        x = BatchNormalization(axis=channel_axis,
                               name=prefix + 'expand/BatchNorm')(x)
        x = _activation(x, activation)

    # Depthwise conv: one kernel per channel, the stride is applied here
    x = DepthwiseConv2D(kernel_size,
                        strides=stride,
                        padding='same',
                        dilation_rate=1,
                        use_bias=False,
                        name=prefix + 'depthwise')(x)
    x = BatchNormalization(axis=channel_axis,
                           name=prefix + 'depthwise/BatchNorm')(x)
    x = _activation(x, activation)

    if se_ratio:
        # Squeeze-and-excitation: re-weight each channel using global context
        reduced_ch = _make_divisible(exp_size * se_ratio, 8)
        y = GlobalAveragePooling2D(name=prefix + 'squeeze_excite/AvgPool')(x)
        y = Reshape([1, 1, exp_size], name=prefix + 'reshape')(y)
        y = Conv2D(reduced_ch,
                   kernel_size=1,
                   padding='same',
                   use_bias=True,
                   name=prefix + 'squeeze_excite/Conv')(y)
        y = Activation("relu", name=prefix + 'squeeze_excite/Relu')(y)
        y = Conv2D(exp_size,
                   kernel_size=1,
                   padding='same',
                   use_bias=True,
                   name=prefix + 'squeeze_excite/Conv_1')(y)
        x = Multiply(name=prefix + 'squeeze_excite/Mul')([Activation(hard_sigmoid)(y), x])

    # Project: linear 1x1 conv (no activation) down to out_channels
    x = Conv2D(out_channels,
               kernel_size=1,
               padding='same',
               use_bias=False,
               name=prefix + 'project')(x)
    x = BatchNormalization(axis=channel_axis,
                           name=prefix + 'project/BatchNorm')(x)

    # Residual connection only when the input and output shapes match
    if in_channels == out_channels and stride == 1:
        x = Add(name=prefix + 'Add')([inputs, x])
    return x

def MobileNetV3(inputs, alpha=1.0, kernel=5, se_ratio=0.25):
    if alpha not in [0.75, 1.0]:
        raise ValueError('Unsupported alpha - {} in MobilenetV3, Use 0.75, 1.0.'.format(alpha))
    # 416,416,3 -> 208,208,16
    x = Conv2D(16, kernel_size=3, strides=(2, 2), padding='same',
               use_bias=False,
               name='Conv')(inputs)
    x = BatchNormalization(axis=-1,
                           epsilon=1e-3,
                           momentum=0.999,
                           name='Conv/BatchNorm')(x)
    x = Activation(hard_swish)(x)

    # 208,208,16 -> 208,208,16
    x = _bneck(x, 1, 16, alpha, 3, 1, None, 'relu', 0)

    # 208,208,16 -> 104,104,24
    x = _bneck(x, 4, 24, alpha, 3, 2, None, 'relu', 1)
    x = _bneck(x, 3, 24, alpha, 3, 1, None, 'relu', 2)

    # 104,104,24 -> 52,52,40
    x = _bneck(x, 3, 40, alpha, kernel, 2, se_ratio, 'relu', 3)
    x = _bneck(x, 3, 40, alpha, kernel, 1, se_ratio, 'relu', 4)
    x = _bneck(x, 3, 40, alpha, kernel, 1, se_ratio, 'relu', 5)
    feat1 = x

    # 52,52,40 -> 26,26,112
    x = _bneck(x, 6, 80, alpha, 3, 2, None, 'hardswish', 6)
    x = _bneck(x, 2.5, 80, alpha, 3, 1, None, 'hardswish', 7)
    x = _bneck(x, 2.3, 80, alpha, 3, 1, None, 'hardswish', 8)
    x = _bneck(x, 2.3, 80, alpha, 3, 1, None, 'hardswish', 9)
    x = _bneck(x, 6, 112, alpha, 3, 1, se_ratio, 'hardswish', 10)
    x = _bneck(x, 6, 112, alpha, 3, 1, se_ratio, 'hardswish', 11)
feat2 <span class="token operator">=</span> x

<span class="token comment"># 26,26,112 -&gt; 13,13,160</span>
x <span class="token operator">=</span> _bneck<span class="token punctuation">(</span>x<span class="token punctuation">,</span> <span class="token number">6</span><span class="token punctuation">,</span> <span class="token number">160</span><span class="token punctuation">,</span> alpha<span class="token punctuation">,</span> kernel<span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> se_ratio<span class="token punctuation">,</span> <span class="token string">'hardswish'</span><span class="token punctuation">,</span> <span class="token number">12</span><span class="token punctuation">)</span>
x <span class="token operator">=</span> _bneck<span class="token punctuation">(</span>x<span class="token punctuation">,</span> <span class="token number">6</span><span class="token punctuation">,</span> <span class="token number">160</span><span class="token punctuation">,</span> alpha<span class="token punctuation">,</span> kernel<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> se_ratio<span class="token punctuation">,</span> <span class="token string">'hardswish'</span><span class="token punctuation">,</span> <span class="token number">13</span><span class="token punctuation">)</span>
x <span class="token operator">=</span> _bneck<span class="token punctuation">(</span>x<span class="token punctuation">,</span> <span class="token number">6</span><span class="token punctuation">,</span> <span class="token number">160</span><span class="token punctuation">,</span> alpha<span class="token punctuation">,</span> kernel<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> se_ratio<span class="token punctuation">,</span> <span class="token string">'hardswish'</span><span class="token punctuation">,</span> <span class="token number">14</span><span class="token punctuation">)</span>
feat3 <span class="token operator">=</span> x

<span class="token keyword">return</span> feat1<span class="token punctuation">,</span>feat2<span class="token punctuation">,</span>feat3
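The 'hardswish' argument passed to _bneck in these layers selects the h-swish activation. The MobileNetV3 paper defines it as h-swish(x) = x * ReLU6(x + 3) / 6, a cheap piecewise-linear stand-in for swish. Below is a minimal NumPy sketch of that formula for illustration only. It is not the repository's implementation, and the function names are chosen here.

```python
import numpy as np


def relu6(x):
    # ReLU clipped to the range [0, 6]
    return np.minimum(np.maximum(x, 0.0), 6.0)


def hard_swish(x):
    # h-swish(x) = x * ReLU6(x + 3) / 6
    # It equals 0 for x <= -3, equals x for x >= 3, and is a smooth-looking ramp in between
    x = np.asarray(x, dtype=np.float64)
    return x * relu6(x + 3.0) / 6.0


if __name__ == "__main__":
    print(hard_swish([-4.0, -3.0, 0.0, 1.0, 3.0, 5.0]))
```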

3. Integrating the extracted features into the YOLOv4 network

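For YOLOv4, the three effective feature layers produced by the backbone are used to build the enhanced feature pyramid (SPP + PANet).

The MobileNetV1, MobileNetV2 and MobileNetV3 functions defined in the previous step each return the three effective feature layers of their network.

These three feature layers replace the effective feature layers of YOLOv4's original backbone, CSPdarknet53.

To reduce the parameter count further, the ordinary 3x3 convolutions in the YOLOv4 neck and head can also be replaced with depthwise separable convolutions.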
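The saving is easy to quantify. The sketch below counts only convolution weights and ignores biases and BatchNorm. It uses two examples: the 16 -> 32 case from the MobileNetV1 introduction, and the 512 -> 1024 3x3 convolution in the P5 branch of yolo_body with alpha = 1. The helper names are made up for illustration.

```python
def standard_conv_params(k, c_in, c_out):
    # One k x k x c_in kernel per output channel
    return k * k * c_in * c_out


def depthwise_separable_params(k, c_in, c_out):
    # One k x k depthwise kernel per input channel,
    # followed by a 1x1 pointwise convolution from c_in to c_out channels
    return k * k * c_in + c_in * c_out


if __name__ == "__main__":
    for c_in, c_out in ((16, 32), (512, 1024)):
        std = standard_conv_params(3, c_in, c_out)
        dws = depthwise_separable_params(3, c_in, c_out)
        print(c_in, c_out, std, dws, round(std / dws, 2))
```

For the 512 -> 1024 case, the depthwise separable version needs about 8.9 times fewer weights than the standard 3x3 convolution.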

The implementation is as follows:

from functools import wraps

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.layers import Conv2D, Add, ZeroPadding2D, UpSampling2D, Concatenate, MaxPooling2D, Activation, DepthwiseConv2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.regularizers import l2
from nets.mobilenet_v1 import MobileNetV1
from nets.mobilenet_v2 import MobileNetV2
from nets.mobilenet_v3 import MobileNetV3
from utils.utils import compose

def relu6(x):
    # ReLU capped at 6, as used in MobileNet
    return K.relu(x, max_value=6)

#--------------------------------------------------#
#   Single convolution
#--------------------------------------------------#
@wraps(Conv2D)
def DarknetConv2D(*args, **kwargs):
    darknet_conv_kwargs = {}
    # 'valid' padding for stride-2 convolutions, 'same' otherwise
    darknet_conv_kwargs['padding'] = 'valid' if kwargs.get('strides') == (2, 2) else 'same'
    darknet_conv_kwargs.update(kwargs)
    return Conv2D(*args, **darknet_conv_kwargs)

#---------------------------------------------------#
#   Convolution block
#   Conv2D + BatchNormalization + ReLU6
#   (the Darknet name is kept, but this version uses ReLU6 instead of LeakyReLU)
#---------------------------------------------------#
def DarknetConv2D_BN_Leaky(*args, **kwargs):
    no_bias_kwargs = {'use_bias': False}
    no_bias_kwargs.update(kwargs)
    return compose(
        DarknetConv2D(*args, **no_bias_kwargs),
        BatchNormalization(),
        Activation(relu6))

#---------------------------------------------------#
#   Depthwise separable convolution block
#   DepthwiseConv2D + BatchNormalization + ReLU6,
#   then 1x1 Conv2D + BatchNormalization + ReLU6
#---------------------------------------------------#
def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha=1,
                          depth_multiplier=1, strides=(1, 1), block_id=1):
    # Scale the number of output channels by the width multiplier alpha
    pointwise_conv_filters = int(pointwise_conv_filters * alpha)

    # 3x3 depthwise convolution: one filter per input channel
    x = DepthwiseConv2D((3, 3),
                        padding='same',
                        depth_multiplier=depth_multiplier,
                        strides=strides,
                        use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = Activation(relu6)(x)

    # 1x1 pointwise convolution: mixes channels and sets the output depth
    x = Conv2D(pointwise_conv_filters, (1, 1),
               padding='same',
               use_bias=False,
               strides=(1, 1))(x)
    x = BatchNormalization()(x)
    return Activation(relu6)(x)

#---------------------------------------------------#
#   Five convolutions
#   1x1 -> 3x3 depthwise separable -> 1x1 -> 3x3 depthwise separable -> 1x1
#---------------------------------------------------#
def make_five_convs(x, num_filters):
    x = DarknetConv2D_BN_Leaky(num_filters, (1, 1))(x)
    x = _depthwise_conv_block(x, num_filters * 2, alpha=1)
    x = DarknetConv2D_BN_Leaky(num_filters, (1, 1))(x)
    x = _depthwise_conv_block(x, num_filters * 2, alpha=1)
    x = DarknetConv2D_BN_Leaky(num_filters, (1, 1))(x)
    return x

#---------------------------------------------------#
#   Feature layers -> final outputs
#---------------------------------------------------#
def yolo_body(inputs, num_anchors, num_classes, backbone="mobilenetv1", alpha=1):
    # Build the backbone. The MobileNet variant replaces CSPdarknet53 and returns
    # three feature layers (52x52, 26x26, 13x13 for a 416x416 input)
    if backbone == "mobilenetv1":
        feat1, feat2, feat3 = MobileNetV1(inputs, alpha=alpha)
    elif backbone == "mobilenetv2":
        feat1, feat2, feat3 = MobileNetV2(inputs, alpha=alpha)
    elif backbone == "mobilenetv3":
        feat1, feat2, feat3 = MobileNetV3(inputs, alpha=alpha)
    else:
        raise ValueError('Unsupported backbone - {}, Use mobilenetv1, mobilenetv2, mobilenetv3.'.format(backbone))

    # 13,13 branch: three convolutions, SPP, three convolutions
    P5 = DarknetConv2D_BN_Leaky(int(512 * alpha), (1, 1))(feat3)
    P5 = _depthwise_conv_block(P5, int(1024 * alpha))
    P5 = DarknetConv2D_BN_Leaky(int(512 * alpha), (1, 1))(P5)
    # SPP: max pooling with three kernel sizes, concatenated with its input
    maxpool1 = MaxPooling2D(pool_size=(13, 13), strides=(1, 1), padding='same')(P5)
    maxpool2 = MaxPooling2D(pool_size=(9, 9), strides=(1, 1), padding='same')(P5)
    maxpool3 = MaxPooling2D(pool_size=(5, 5), strides=(1, 1), padding='same')(P5)
    P5 = Concatenate()([maxpool1, maxpool2, maxpool3, P5])
    P5 = DarknetConv2D_BN_Leaky(int(512 * alpha), (1, 1))(P5)
    P5 = _depthwise_conv_block(P5, int(1024 * alpha))
    P5 = DarknetConv2D_BN_Leaky(int(512 * alpha), (1, 1))(P5)

    # PANet top-down path: 13,13 -> 26,26
    P5_upsample = compose(DarknetConv2D_BN_Leaky(int(256 * alpha), (1, 1)), UpSampling2D(2))(P5)

    P4 = DarknetConv2D_BN_Leaky(int(256 * alpha), (1, 1))(feat2)
    P4 = Concatenate()([P4, P5_upsample])
    P4 = make_five_convs(P4, int(256 * alpha))

    # 26,26 -> 52,52
    P4_upsample = compose(DarknetConv2D_BN_Leaky(int(128 * alpha), (1, 1)), UpSampling2D(2))(P4)

    P3 = DarknetConv2D_BN_Leaky(int(128 * alpha), (1, 1))(feat1)
    P3 = Concatenate()([P3, P4_upsample])
    P3 = make_five_convs(P3, int(128 * alpha))

    # 52,52 output
    P3_output = _depthwise_conv_block(P3, int(256 * alpha))
    P3_output = DarknetConv2D(num_anchors * (num_classes + 5), (1, 1))(P3_output)

    # Bottom-up path, 26,26 output
    P3_downsample = _depthwise_conv_block(P3, int(256 * alpha), strides=(2, 2))
    P4 = Concatenate()([P3_downsample, P4])
    P4 = make_five_convs(P4, int(256 * alpha))

    P4_output = _depthwise_conv_block(P4, int(512 * alpha))
    P4_output = DarknetConv2D(num_anchors * (num_classes + 5), (1, 1))(P4_output)

    # 13,13 output
    P4_downsample = _depthwise_conv_block(P4, int(512 * alpha), strides=(2, 2))
    P5 = Concatenate()([P4_downsample, P5])
    P5 = make_five_convs(P5, int(512 * alpha))

    P5_output = _depthwise_conv_block(P5, int(1024 * alpha))
    P5_output = DarknetConv2D(num_anchors * (num_classes + 5), (1, 1))(P5_output)

    return Model(inputs, [P5_output, P4_output, P3_output])
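A quick sanity check on the three heads is to compute their shapes by hand. yolo_body returns [P5_output, P4_output, P3_output], which sit at strides 32, 16 and 8. Each head has num_anchors * (num_classes + 5) channels. The function below is a hypothetical pure-Python helper and is not part of the repository. Its multiple-of-32 check mirrors the input_shape comment in train.py.

```python
def yolo_output_shapes(input_shape, num_anchors, num_classes):
    # Hypothetical helper: output shapes in the order yolo_body returns them,
    # i.e. [P5_output, P4_output, P3_output] at strides 32, 16 and 8
    h, w = input_shape
    if h % 32 or w % 32:
        raise ValueError("input_shape must be a multiple of 32")
    # 4 box parameters + 1 objectness score + one score per class, for each anchor
    channels = num_anchors * (num_classes + 5)
    return [(h // s, w // s, channels) for s in (32, 16, 8)]


if __name__ == "__main__":
    # VOC: 20 classes, 3 anchors per output scale
    for shape in yolo_output_shapes((416, 416), 3, 20):
        print(shape)
```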

How to train your own mobilenet-yolov4

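First download the repository from GitHub, unzip it, and open the folder in your editor or IDE.
Make sure the folder you open as the root is the directory that actually contains the code. Otherwise the relative paths will be wrong and the code will not run.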

1. Preparing the dataset

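This article trains on data in VOC format, so you need to prepare your dataset before training. If you do not have a dataset of your own, you can download the VOC12+07 dataset from the link on the GitHub page and try it out.
Before training, put the label (.xml) files in VOCdevkit/VOC2007/Annotations.
Put the image files in VOCdevkit/VOC2007/JPEGImages.
The dataset is now in place.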
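Before moving on, it helps to confirm that every annotation has a matching image and vice versa. Below is a minimal standard-library sketch. It assumes annotations are .xml files and images are .jpg files whose names match apart from the extension, which is the usual VOC convention. find_unpaired is a hypothetical helper.

```python
from pathlib import Path


def _stems(folder, pattern):
    # File names without extension; an empty set if the folder does not exist
    folder = Path(folder)
    return {p.stem for p in folder.glob(pattern)} if folder.is_dir() else set()


def find_unpaired(voc_root):
    # Hypothetical helper: assumes Annotations/*.xml and JPEGImages/*.jpg
    root = Path(voc_root)
    xml_stems = _stems(root / "Annotations", "*.xml")
    jpg_stems = _stems(root / "JPEGImages", "*.jpg")
    # (annotations without an image, images without an annotation)
    return sorted(xml_stems - jpg_stems), sorted(jpg_stems - xml_stems)


if __name__ == "__main__":
    no_image, no_label = find_unpaired("VOCdevkit/VOC2007")
    print("annotations without images:", no_image)
    print("images without annotations:", no_label)
```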

2. Processing the dataset

Once the dataset is in place, it needs further processing to produce 2007_train.txt and 2007_val.txt for training. This is done with voc_annotation.py in the root directory.
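The sketch below shows roughly what this step produces by turning one VOC .xml annotation into a single text line. The line format is an assumption for illustration: the image path, followed by space-separated x_min,y_min,x_max,y_max,class_id groups. The format actually written by voc_annotation.py may differ. The sketch also shows why classes_path matters. Objects whose class name is not in the class list are skipped, so a wrong class list produces lines with no objects.

```python
import xml.etree.ElementTree as ET


def annotation_to_line(image_path, xml_text, classes):
    # Hypothetical helper: image path, then one "x_min,y_min,x_max,y_max,class_id" group per object
    root = ET.fromstring(xml_text)
    groups = []
    for obj in root.iter("object"):
        name = obj.find("name").text.strip()
        if name not in classes:
            continue  # objects outside the class list are dropped
        box = obj.find("bndbox")
        coords = [int(float(box.find(tag).text)) for tag in ("xmin", "ymin", "xmax", "ymax")]
        groups.append(",".join(str(v) for v in coords + [classes.index(name)]))
    return " ".join([image_path] + groups)


SAMPLE_XML = """<annotation>
  <object>
    <name>dog</name>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
  <object>
    <name>cat</name>
    <bndbox><xmin>8</xmin><ymin>12</ymin><xmax>352</xmax><ymax>498</ymax></bndbox>
  </object>
</annotation>"""

if __name__ == "__main__":
    print(annotation_to_line("JPEGImages/000001.jpg", SAMPLE_XML, ["person", "dog"]))
```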

voc_annotation.py has several parameters to set: annotation_mode, classes_path, trainval_percent, train_percent and VOCdevkit_path. For a first training run, only classes_path needs to be changed.

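'''
annotation_mode specifies what this script computes when it runs
annotation_mode = 0: the full label-processing pipeline, i.e. the txt files in VOCdevkit/VOC2007/ImageSets plus 2007_train.txt and 2007_val.txt for training
annotation_mode = 1: only the txt files in VOCdevkit/VOC2007/ImageSets
annotation_mode = 2: only 2007_train.txt and 2007_val.txt for training
'''
annotation_mode     = 0
'''
Must be changed: used to write the object information into 2007_train.txt and 2007_val.txt
It only needs to match the classes_path used for training and prediction
If the generated 2007_train.txt contains no object information,
the classes were not set correctly
Only takes effect when annotation_mode is 0 or 2
'''
classes_path        = 'model_data/voc_classes.txt'
'''
trainval_percent sets the ratio of (train + val) to test; by default (train + val) : test = 9 : 1
train_percent sets the ratio of train to val within (train + val); by default train : val = 9 : 1
Only takes effect when annotation_mode is 0 or 1
'''
trainval_percent    = 0.9
train_percent       = 0.9
'''
Points to the folder containing the VOC dataset
By default it points to the VOC dataset in the root directory
'''
VOCdevkit_path  = 'VOCdevkit'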

 
 
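The two ratios work in sequence. trainval_percent first separates (train + val) from test, and then train_percent splits (train + val) into train and val. Below is a minimal sketch of that arithmetic using random sampling. split_ids is hypothetical, and the repository's script may round or shuffle differently. With 100 images and the default values of 0.9 and 0.9, it yields 81 train, 9 val and 10 test images.

```python
import random


def split_ids(ids, trainval_percent=0.9, train_percent=0.9, seed=0):
    # Hypothetical helper:
    # (train + val) : test = trainval_percent : (1 - trainval_percent)
    # train : val = train_percent : (1 - train_percent), inside (train + val)
    ids = list(ids)
    rng = random.Random(seed)
    num_trainval = int(len(ids) * trainval_percent)
    num_train = int(num_trainval * train_percent)
    trainval_list = rng.sample(ids, num_trainval)
    train = set(rng.sample(trainval_list, num_train))
    trainval = set(trainval_list)
    # Keep the original order of ids inside each split
    train_ids = [i for i in ids if i in train]
    val_ids = [i for i in ids if i in trainval and i not in train]
    test_ids = [i for i in ids if i not in trainval]
    return train_ids, val_ids, test_ids


if __name__ == "__main__":
    train, val, test = split_ids([f"{i:06d}" for i in range(100)])
    print(len(train), len(val), len(test))
```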

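classes_path points to the txt file that lists the detection classes. For the VOC dataset, the txt we use is:
[Figure: contents of model_data/voc_classes.txt]
When training on your own dataset, create your own cls_classes.txt and list the classes you need to distinguish in it.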
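This sketch assumes the class file is plain text with one class name per line, with each class's id given by its line position. read_class_names is a hypothetical helper. Skipping blank lines is an extra assumption made here, so that a trailing empty line does not become an empty class.

```python
import os


def read_class_names(classes_path):
    # Hypothetical helper: one class name per line, blank lines skipped;
    # a class's id is its position in the returned list
    with open(classes_path, encoding="utf-8") as f:
        names = [line.strip() for line in f if line.strip()]
    return names, len(names)


if __name__ == "__main__":
    path = "model_data/voc_classes.txt"
    if os.path.exists(path):
        names, num_classes = read_class_names(path)
        print(num_classes, names)
```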

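3. Training the network

voc_annotation.py has now generated 2007_train.txt and 2007_val.txt, so training can begin.
There are quite a few training parameters, and it is worth reading the comments carefully after downloading the repository. The most important one is again classes_path in train.py.

classes_path points to the txt file listing the detection classes. It is the same txt as the one used in voc_annotation.py, and it must be changed when training on your own dataset!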
After changing classes_path, run train.py to start training. After several epochs, the weights are saved in the logs folder.

The backbone parameter selects the backbone feature extraction network: mobilenetv1, mobilenetv2 or mobilenetv3.

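The alpha parameter is the width multiplier of the selected MobileNet: it scales the number of channels in each layer. Its default value is 1.
For mobilenetv1, alpha can be 0.25, 0.5, 0.75 or 1.0.
For mobilenetv2, alpha can be 0.5, 0.75, 1.0 or 1.3.
For mobilenetv3, alpha can be 0.75 or 1.0.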
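The backbone, the alpha value and the pretrained weights must agree, so a small check before building the model can catch a typo early. The minimal sketch below encodes the combinations listed above. check_backbone_alpha is hypothetical and only validates the backbone/alpha pair. It cannot tell whether a given .h5 file was actually trained with that pair.

```python
SUPPORTED_ALPHAS = {
    "mobilenetv1": (0.25, 0.5, 0.75, 1.0),
    "mobilenetv2": (0.5, 0.75, 1.0, 1.3),
    "mobilenetv3": (0.75, 1.0),
}


def check_backbone_alpha(backbone, alpha):
    # Hypothetical helper: reject backbones yolo_body does not build and unsupported alphas
    if backbone not in SUPPORTED_ALPHAS:
        raise ValueError(f"Unsupported backbone {backbone!r}; use one of {sorted(SUPPORTED_ALPHAS)}")
    allowed = SUPPORTED_ALPHAS[backbone]
    # Compare with a tolerance so that alpha = 1 and alpha = 1.0 are treated alike
    if not any(abs(alpha - a) < 1e-9 for a in allowed):
        raise ValueError(f"alpha={alpha} is not available for {backbone}; choose from {allowed}")
    return backbone, float(alpha)


if __name__ == "__main__":
    print(check_backbone_alpha("mobilenetv2", 1))
```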

Before training, make sure the MobileNet version, the alpha value and the pretrained weights all match.

The other parameters work as follows:

#--------------------------------------------------------#
#   训练前一定要修改classes_path,使其对应自己的数据集
#--------------------------------------------------------#
classes_path    = 'model_data/voc_classes.txt'
#---------------------------------------------------------------------#
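#   anchors_path is the txt file holding the anchor (prior) boxes; usually left unchanged.
#   anchors_mask helps the code find the anchors for each output layer; usually left unchanged.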
#---------------------------------------------------------------------#
anchors_path    = 'model_data/yolo_anchors.txt'
anchors_mask    = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
#------------------------------------------------------------------------------------------------------#
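#   See the README for the weight files (downloadable from Baidu Netdisk). Pretrained weights work across datasets because the features are general.
#   Pretrained weights should be used in 99% of cases; without them the weights are too random, feature extraction is weak and training results will be poor.
#   Dimension-mismatch warnings are normal when training on your own dataset: the prediction targets differ, so those layer shapes do not match.
#   To resume training from a checkpoint, set model_path to a weights file already saved in the logs folder.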
#------------------------------------------------------------------------------------------------------#
model_path      = 'model_data/yolov4_mobilenet_v2_voc.h5'
#------------------------------------------------------#
#   Input shape; must be a multiple of 32
#------------------------------------------------------#
input_shape     = [416, 416]
#--------------------------------------------------#
#   Make sure backbone, alpha and the weights file correspond!
#   Available alpha values for mobilenetv1: 0.25, 0.5, 0.75, 1.0
#   Available alpha values for mobilenetv2: 0.5, 0.75, 1.0, 1.3
#   Available alpha values for mobilenetv3: 0.75, 1.0
#   Available alpha values for ghostnet: 1.0
#   See the README for downloading the weights
backbone        = "mobilenetv2"
alpha           = 1
#------------------------------------------------------#
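#   YOLOv4 tricks
#   mosaic: mosaic data augmentation, True or False
#   In practice mosaic augmentation proved unstable, so it defaults to False
#   Cosine_scheduler: cosine annealing learning rate, True or False
#   label_smoothing: label smoothing, usually 0.01 or lower, e.g. 0.01 or 0.005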
#------------------------------------------------------#
mosaic              = False
Cosine_scheduler    = False
label_smoothing     = 0

#----------------------------------------------------#
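# Training has two stages: a frozen stage and an unfrozen stage.
# Running out of GPU memory has nothing to do with dataset size; if you get an out-of-memory error, reduce batch_size.
# Because of the BatchNorm layers, batch_size must be at least 2; it cannot be 1.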
#----------------------------------------------------#
#----------------------------------------------------#
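# Frozen-stage training parameters
# The backbone is frozen, so the feature extraction network does not change
# Uses less GPU memory; only fine-tunes the network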
#----------------------------------------------------#
Init_Epoch = 0
Freeze_Epoch = 50
Freeze_batch_size = 16
Freeze_lr = 1e-3
#----------------------------------------------------#
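# Unfrozen-stage training parameters
# The backbone is no longer frozen, so the feature extraction network changes
# Uses more GPU memory; all network parameters are updated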
#----------------------------------------------------#
UnFreeze_Epoch = 100
Unfreeze_batch_size = 8
Unfreeze_lr = 1e-4
#------------------------------------------------------#
# Whether to use freeze training; by default the backbone is first trained frozen, then unfrozen.
#------------------------------------------------------#
Freeze_Train = True
#------------------------------------------------------#
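# Whether to read data with multiple workers; 0 disables multithreading
# Enabling it speeds up data loading but uses more memory
# In Keras, enabling multithreading is sometimes much slower
# Only enable it when IO is the bottleneck, i.e. the GPU computes much faster than images can be read.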
#------------------------------------------------------#
num_workers = 0
#----------------------------------------------------#
# Image paths and labels
#----------------------------------------------------#
train_annotation_path = '2007_train.txt'
val_annotation_path = '2007_val.txt'
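
The "multiple of 32" rule and anchors_mask are easier to see with numbers. The sketch assumes, as in standard YOLOv4, that the three output layers downsample the input by 32, 16 and 8, in the same order as anchors_mask, and that the nine anchors in yolo_anchors.txt are sorted from small to large, so the largest anchors (6, 7, 8) go to the coarsest grid. yolo_grid_shapes is a hypothetical helper.

```python
def yolo_grid_shapes(input_shape, anchors_mask=((6, 7, 8), (3, 4, 5), (0, 1, 2))):
    """Map each YOLO output layer to (grid size, anchor indices). Hypothetical helper.

    Assumes the three output layers have strides 32, 16 and 8, in the same
    order as anchors_mask.
    """
    h, w = input_shape
    if h % 32 or w % 32:
        # The deepest layer downsamples by 32, so smaller multiples do not divide evenly
        raise ValueError(f"input_shape {input_shape} must be a multiple of 32")
    strides = (32, 16, 8)
    return [((h // s, w // s), list(mask)) for s, mask in zip(strides, anchors_mask)]


if __name__ == "__main__":
    for grid, mask in yolo_grid_shapes([416, 416]):
        print(grid, mask)
```

With input_shape = [416, 416] the grids are 13×13, 26×26 and 52×52, using anchors [6, 7, 8], [3, 4, 5] and [0, 1, 2] respectively.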

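The freeze/unfreeze settings, Cosine_scheduler and label_smoothing can be summarized in a few lines of plain Python. This is a sketch under stated assumptions, not the repository's training loop: training_phases, cosine_lr and smooth_labels are hypothetical helpers; cosine_lr is plain cosine annealing without any warmup a real scheduler may add; and smooth_labels uses the common y * (1 - eps) + eps / num_classes formulation, which is assumed rather than confirmed to be what the repository does.

```python
import math


def training_phases(init_epoch=0, freeze_epoch=50, unfreeze_epoch=100,
                    freeze_batch_size=16, unfreeze_batch_size=8,
                    freeze_lr=1e-3, unfreeze_lr=1e-4, freeze_train=True):
    """Describe the two training stages as dicts covering epochs [start, end)."""
    return [
        # Stage 1: backbone frozen (unless freeze_train is False); less memory, larger batch
        dict(start=init_epoch, end=freeze_epoch, batch_size=freeze_batch_size,
             lr=freeze_lr, backbone_trainable=not freeze_train),
        # Stage 2: everything trainable; more memory, so a smaller batch
        dict(start=freeze_epoch, end=unfreeze_epoch, batch_size=unfreeze_batch_size,
             lr=unfreeze_lr, backbone_trainable=True),
    ]


def cosine_lr(epoch, start, end, base_lr, min_lr=0.0):
    """Plain cosine annealing: base_lr at `start`, decaying to min_lr at `end`."""
    progress = (epoch - start) / max(1, end - start)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


def smooth_labels(one_hot, eps):
    """Label smoothing of one target vector: y * (1 - eps) + eps / num_classes."""
    num_classes = len(one_hot)
    return [y * (1 - eps) + eps / num_classes for y in one_hot]


if __name__ == "__main__":
    for phase in training_phases():
        print(phase)
    print(cosine_lr(25, 0, 50, 1e-3), smooth_labels([0, 1, 0], 0.03))
```

In Keras, each stage would typically be run by setting `trainable = False` (or `True`) on the backbone layers, compiling the model again so the change takes effect, and calling `model.fit(..., initial_epoch=start, epochs=end)`.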

4. Predicting with the trained model

Prediction uses two files: yolo.py and predict.py.
First, set model_path and classes_path in yolo.py; both of these must be changed.

As in train.py, the backbone parameter in yolo.py selects the backbone feature extraction network: mobilenetv1, mobilenetv2 or mobilenetv3.

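The alpha parameter sets the width multiplier of the chosen mobilenet network, i.e. how its channel counts are scaled; the default is 1, and it must match the value used for training.
Available alpha values for mobilenetv1: 0.25, 0.5, 0.75, 1.0.
Available alpha values for mobilenetv2: 0.5, 0.75, 1.0, 1.3.
Available alpha values for mobilenetv3: 0.75, 1.0.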

model_path points to the trained weights file in the logs folder.
classes_path points to the txt file listing the detection classes.
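
Since the prediction model is built from backbone and alpha before the weights in model_path are loaded, those settings and classes_path must match the ones used in train.py. The sketch compares two plain dicts holding these settings; the dicts and check_predict_config are hypothetical stand-ins for the variables in train.py and yolo.py.

```python
def check_predict_config(train_cfg, predict_cfg,
                         keys=("classes_path", "backbone", "alpha")):
    """Return the keys whose values differ between training and prediction settings."""
    return [k for k in keys if train_cfg.get(k) != predict_cfg.get(k)]


if __name__ == "__main__":
    train_cfg = {"classes_path": "model_data/voc_classes.txt",
                 "backbone": "mobilenetv2", "alpha": 1}
    predict_cfg = dict(train_cfg, backbone="mobilenetv3")  # deliberately mismatched
    print(check_predict_config(train_cfg, predict_cfg))
```

An empty list means the two sets of settings agree; alpha values of 1 and 1.0 compare as equal.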

After making these changes, run predict.py and enter an image path to run detection.
