(From an expert) Wise Object Detection 16: Building an SSD Object Detection Platform with Keras

Original article: https://blog.csdn.net/weixin_44791964/article/details/104107271

Preface

Let's take a look at a Keras implementation of SSD and train it on our own data along the way.

What is the SSD object detection algorithm?

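SSD is a widely used one-stage object detector. In a one-stage algorithm, localization and classification happen together. A CNN first extracts features. The detector then samples densely and uniformly at different positions in the image, using different scales and aspect ratios, and performs object classification and bounding-box regression at the same time. Because the whole process takes a single step, its main advantage is speed.
However, uniform dense sampling has a major drawback: training is relatively difficult. The main reason is the extreme imbalance between positive samples and negative (background) samples (see Focal Loss), which makes the model somewhat less accurate.
SSD stands for Single Shot MultiBox Detector. "Single Shot" indicates that SSD is a one-stage method, and "MultiBox" indicates that it predicts multiple boxes.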
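To see how dense this sampling is, the sketch below counts the default boxes an SSD300-style model places on one image. Some inputs are assumptions:

- The six feature-map sizes and the number of boxes per location come from the original SSD300 configuration, not from this post.
- `count_default_boxes` is a hypothetical helper.
- The 30 positive matches are a made-up illustrative number.

```python
# Assumed SSD300 configuration (not taken from this post):
# (feature-map side length, default boxes per cell)
SSD300_LAYERS = [(38, 4), (19, 6), (10, 6), (5, 6), (3, 4), (1, 4)]


def count_default_boxes(layers):
    """Total number of default boxes sampled over all feature maps."""
    return sum(side * side * boxes for side, boxes in layers)


total = count_default_boxes(SSD300_LAYERS)
print(total)  # 8732

# Hypothetical image: 3 objects, each matched to about 10 default boxes.
num_positive = 30
print(f"positives: {num_positive / total:.2%}")  # positives: 0.34%
```

Every default box that is not matched to an object is treated as a background (negative) sample. Under these assumptions only about 0.34% of boxes are positive, which is the imbalance described above.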

Source code download

https://github.com/bubbliiiing/ssd-keras
If you find it useful, feel free to give it a star.

SSD implementation approach

I. Prediction

1. Backbone network

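SSD uses VGG as its backbone. For an introduction to VGG, see my other blog post: https://blog.csdn.net/weixin_44791964/article/details/102779878. The VGG used here differs from the standard network in a few ways. The main modifications are:
1. The FC6 and FC7 layers of VGG16 are converted into convolutional layers.
2. All Dropout layers and the FC8 layer are removed.
3. The new layers Conv6, Conv7, Conv8 and Conv9 are added.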
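Why can a fully connected layer be replaced by a convolution? A 1x1 convolution, which is what fc7 becomes in the implementation, is the same dense layer applied independently at every spatial position. The weights therefore carry over, and the layer keeps the spatial layout of the feature map.

The pure-Python sketch below checks this on a tiny made-up feature map. `dense` and `conv1x1` are hypothetical helper names, and the numbers are arbitrary.

```python
def dense(vector, weights, bias):
    """Fully connected layer: out[j] = sum_i vector[i] * weights[i][j] + bias[j]."""
    return [sum(v * w[j] for v, w in zip(vector, weights)) + bias[j]
            for j in range(len(bias))]


def conv1x1(feature_map, weights, bias):
    """1x1 convolution on an H x W x C feature map stored as nested lists."""
    return [[dense(pixel, weights, bias) for pixel in row] for row in feature_map]


# Tiny 2x2 feature map with 3 input channels, mapped to 2 output channels.
fmap = [[[1, 0, 2], [0, 1, 1]],
        [[3, 1, 0], [1, 1, 1]]]
W = [[1, 0], [0, 1], [1, -1]]  # shape (in_channels=3, out_channels=2)
b = [0, 1]

out = conv1x1(fmap, W, b)
# Every output pixel equals the dense layer applied to that pixel's channel vector.
assert all(out[i][j] == dense(fmap[i][j], W, b) for i in range(2) for j in range(2))
print(out)  # [[[3, -1], [1, 1]], [[3, 2], [2, 1]]]
```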

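As shown in the figure, the input image passes through the modified VGG network (Conv1->fc7) and several additional convolutional layers (Conv6->Conv9) for feature extraction:
a. The input image is first resized to 300x300.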

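b. conv1: two [3,3] convolutions with 64 output channels give (300,300,64). A 2X2 max pooling then gives net (150,150,64).

c. conv2: two [3,3] convolutions with 128 output channels give net (150,150,128). A 2X2 max pooling then gives net (75,75,128).

d. conv3: three [3,3] convolutions with 256 output channels give net (75,75,256). A 2X2 max pooling with 'same' padding then rounds 75/2 up, giving net (38,38,256).

e. conv4: three [3,3] convolutions with 512 output channels give net (38,38,512). A 2X2 max pooling then gives net (19,19,512).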

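f. conv5: three [3,3] convolutions with 512 output channels give net (19,19,512). A 3X3 max pooling with stride 1 follows, so net stays at (19,19,512).

g. Convolutions replace the fully connected layers. A [3,3] convolution with dilation rate 6 (fc6) is followed by a [1,1] convolution (fc7), each with 1024 output channels, so the output net is (19,19,1024). Everything up to this point belongs to the VGG structure.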

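h. conv6: a [1,1] convolution adjusts the number of channels. Then, after 1 pixel of zero padding, a [3,3] convolution with stride 2 outputs 512 channels, so the output net is (10,10,512).

i. conv7: a [1,1] convolution adjusts the number of channels. Then, after 1 pixel of zero padding, a [3,3] convolution with stride 2 outputs 256 channels, so the output net is (5,5,256).

j. conv8: a [1,1] convolution adjusts the number of channels. Then a [3,3] convolution with 'valid' padding outputs 256 channels, so the output net is (3,3,256).

k. conv9: a [1,1] convolution adjusts the number of channels. Then a [3,3] convolution with 'valid' padding outputs 256 channels, so the output net is (1,1,256).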
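The spatial sizes in steps a to k follow from the usual output-size rules:

- With padding='same', the output side is ceil(n / stride).
- With padding='valid', the output side is floor((n + 2*pad - k_eff) / stride) + 1.
- A dilated kernel has effective size k_eff = k + (k - 1)(d - 1).

As a sanity check, the sketch below replays these rules on the 300x300 input. `out_size` is a hypothetical helper. The 1-pixel padding before conv6 and conv7 mirrors the ZeroPadding2D layers in the implementation.

```python
import math


def out_size(n, kernel=1, stride=1, padding="same", pad=0, dilation=1):
    """Output side length of a conv/pool layer applied to an n x n input."""
    if padding == "same":
        # 'same' padding: the size depends only on the stride, not on kernel or dilation
        return math.ceil(n / stride)
    k_eff = kernel + (kernel - 1) * (dilation - 1)  # effective size of a dilated kernel
    return (n + 2 * pad - k_eff) // stride + 1


n = 300
sizes = []
for _ in range(4):                      # pool1..pool4: 2x2, stride 2, 'same'
    n = out_size(n, 2, 2)
    sizes.append(n)
n = out_size(n, 3, 1)                   # pool5: 3x3, stride 1, 'same' -> unchanged
n = out_size(n, 3, 1, dilation=6)       # fc6: dilated 3x3, 'same' -> unchanged
sizes.append(n)
n = out_size(n, 3, 2, "valid", pad=1)   # conv6: zero-pad 1, then 3x3 stride 2
sizes.append(n)
n = out_size(n, 3, 2, "valid", pad=1)   # conv7: zero-pad 1, then 3x3 stride 2
sizes.append(n)
n = out_size(n, 3, 1, "valid")          # conv8: 3x3, 'valid'
sizes.append(n)
n = out_size(n, 3, 1, "valid")          # conv9: 3x3, 'valid'
sizes.append(n)
print(sizes)  # [150, 75, 38, 19, 19, 10, 5, 3, 1]
```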

Implementation code:

import keras.backend as K
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import GlobalAveragePooling2D
from keras.layers import Input
from keras.layers import MaxPooling2D
from keras.layers import merge, concatenate
from keras.layers import Reshape
from keras.layers import ZeroPadding2D
from keras.models import Model

def VGG16(input_tensor):
    #---------------------------- Backbone feature extraction network: start ----------------------------#
    # SSD structure: every intermediate output is stored in the dict `net`
    net = {}
    # Block 1
    net['input'] = input_tensor
    # 300,300,3 -> 150,150,64
    net['conv1_1'] = Conv2D(64, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv1_1')(net['input'])
    net['conv1_2'] = Conv2D(64, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv1_2')(net['conv1_1'])
    net['pool1'] = MaxPooling2D((2, 2), strides=(2, 2), padding='same',
                                name='pool1')(net['conv1_2'])

    # Block 2
    # 150,150,64 -> 75,75,128
    net['conv2_1'] = Conv2D(128, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv2_1')(net['pool1'])
    net['conv2_2'] = Conv2D(128, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv2_2')(net['conv2_1'])
    net['pool2'] = MaxPooling2D((2, 2), strides=(2, 2), padding='same',
                                name='pool2')(net['conv2_2'])
    # Block 3
    # 75,75,128 -> 38,38,256
    net['conv3_1'] = Conv2D(256, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv3_1')(net['pool2'])
    net['conv3_2'] = Conv2D(256, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv3_2')(net['conv3_1'])
    net['conv3_3'] = Conv2D(256, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv3_3')(net['conv3_2'])
    net['pool3'] = MaxPooling2D((2, 2), strides=(2, 2), padding='same',
                                name='pool3')(net['conv3_3'])
    # Block 4
    # 38,38,256 -> 19,19,512
    net['conv4_1'] = Conv2D(512, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv4_1')(net['pool3'])
    net['conv4_2'] = Conv2D(512, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv4_2')(net['conv4_1'])
    net['conv4_3'] = Conv2D(512, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv4_3')(net['conv4_2'])
    net['pool4'] = MaxPooling2D((2, 2), strides=(2, 2), padding='same',
                                name='pool4')(net['conv4_3'])
    # Block 5
    # 19,19,512 -> 19,19,512
    net['conv5_1'] = Conv2D(512, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv5_1')(net['pool4'])
    net['conv5_2'] = Conv2D(512, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv5_2')(net['conv5_1'])
    net['conv5_3'] = Conv2D(512, kernel_size=(3,3),
                            activation='relu',
                            padding='same',
                            name='conv5_3')(net['conv5_2'])
    # 3x3 pooling with stride 1: the spatial size stays 19x19
    net['pool5'] = MaxPooling2D((3, 3), strides=(1, 1), padding='same',
                                name='pool5')(net['conv5_3'])
    # FC6: a dilated 3x3 convolution replaces the fully connected fc6
    # 19,19,512 -> 19,19,1024
    net['fc6'] = Conv2D(1024, kernel_size=(3,3), dilation_rate=(6, 6),
                        activation='relu', padding='same',
                        name='fc6')(net['pool5'])

    # Dropout from the original VGG16 is removed: x = Dropout(0.5, name='drop6')(x)
    # FC7: a 1x1 convolution replaces the fully connected fc7
    # 19,19,1024 -> 19,19,1024
    net['fc7'] = Conv2D(1024, kernel_size=(1,1), activation='relu',
                        padding='same', name='fc7')(net['fc6'])

    # Dropout from the original VGG16 is removed: x = Dropout(0.5, name='drop7')(x)
    # Block 6
    # 19,19,1024 -> 10,10,512
    net['conv6_1'] = Conv2D(256, kernel_size=(1,1), activation='relu',
                            padding='same',
                            name='conv6_1')(net['fc7'])
    # Pad by 1 pixel, then a 3x3 stride-2 conv with the default 'valid' padding: 19 -> 10
    net['conv6_2'] = ZeroPadding2D(padding=((1, 1), (1, 1)), name='conv6_padding')(net['conv6_1'])
    net['conv6_2'] = Conv2D(512, kernel_size=(3,3), strides=(2, 2),
                            activation='relu',
                            name='conv6_2')(net['conv6_2'])

    # Block 7
    # 10,10,512 -> 5,5,256
    net['conv7_1'] = Conv2D(128, kernel_size=(1,1), activation='relu',
                            padding='same',
                            name='conv7_1')(net['conv6_2'])
    # Pad by 1 pixel, then a 3x3 stride-2 'valid' conv: 10 -> 5
    net['conv7_2'] = ZeroPadding2D(padding=((1, 1), (1, 1)), name='conv7_padding')(net['conv7_1'])
    net['conv7_2'] = Conv2D(256, kernel_size=(3,3), strides=(2, 2),
                            activation='relu', padding='valid',
                            name='conv7_2')(net['conv7_2'])
    # Block 8
    # 5,5,256 -> 3,3,256
    net['conv8_1'] = Conv2D(128, kernel_size=(1,1), activation='relu',
                            padding='same',
                            name='conv8_1')(net['conv7_2'])
    net['conv8_2'] = Conv2D(256, kernel_size=(3,3), strides=(1, 1),
                               activation<span class="token operator">=</span><span class="token string">'relu'</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token string">'valid'</span><span class="token punctuation">,</span>
                               name<span class="token operator">=</span><span class="token string">'conv8_2'</span><span class="token punctuation">)</span><span class="token punctuation">(</span>net<span class="token punctuation">[</span><span class="token string">'conv8_1'</span><span class="token punctuation">]</span><span class="token punctuation">)</span>

<span class="token comment"># Block 9</span>
<span class="token comment"># 3,3,256 -&gt; 1,1,256</span>
net<span class="token punctuation">[</span><span class="token string">'conv9_1'</span><span class="token punctuation">]</span> <span class="token operator">=</span> Conv2D<span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span> kernel_size<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> activation<span class="token operator">=</span><span class="token string">'relu'</span><span class="token punctuation">,</span>
                               padding<span class="token operator">=</span><span class="token string">'same'</span><span class="token punctuation">,</span>
                               name<span class="token operator">=</span><span class="token string">'conv9_1'</span><span class="token punctuation">)</span><span class="token punctuation">(</span>net<span class="token punctuation">[</span><span class="token string">'conv8_2'</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
net<span class="token punctuation">[</span><span class="token string">'conv9_2'</span><span class="token punctuation">]</span> <span class="token operator">=</span> Conv2D<span class="token punctuation">(</span><span class="token number">256</span><span class="token punctuation">,</span> kernel_size<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">3</span><span class="token punctuation">,</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> strides<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                               activation<span class="token operator">=</span><span class="token string">'relu'</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token string">'valid'</span><span class="token punctuation">,</span>
                               name<span class="token operator">=</span><span class="token string">'conv9_2'</span><span class="token punctuation">)</span><span class="token punctuation">(</span>net<span class="token punctuation">[</span><span class="token string">'conv9_1'</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
    #----------------------------end of the backbone feature extraction network---------------------------#
    return net
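To check the shapes listed in steps h to k, the short sketch below applies the standard Keras output-size rules (padding='same' gives ceil(n / stride); padding='valid' gives (n - k) // stride + 1) to the extra layers. It assumes conv6_2 is a 3x3, stride-2 convolution with padding='same' (its definition sits above this excerpt). `conv_out` and `extra_layer_sizes` are hypothetical helper names, and the sketch does not need Keras to run.

```python
import math

def conv_out(size, kernel, stride, padding):
    """Output size along one spatial axis of a Keras Conv2D / MaxPooling2D layer."""
    if padding == "same":
        # 'same' pads the input so the output is ceil(size / stride)
        return math.ceil(size / stride)
    # 'valid' keeps only the positions where the whole kernel fits inside the input
    return (size - kernel) // stride + 1

def extra_layer_sizes(fc7_size=19):
    """Trace the feature-map side length from fc7 through conv9_2."""
    sizes = {"fc7": fc7_size}
    # conv6_2: 3x3, stride 2, assumed padding='same': 19 becomes 10
    sizes["conv6_2"] = conv_out(sizes["fc7"], 3, 2, "same")
    # conv7_2: ZeroPadding2D adds 1 pixel per side, then 3x3, stride 2, 'valid': 10 becomes 5
    sizes["conv7_2"] = conv_out(sizes["conv6_2"] + 2, 3, 2, "valid")
    # conv8_2: 3x3, stride 1, 'valid': 5 becomes 3
    sizes["conv8_2"] = conv_out(sizes["conv7_2"], 3, 1, "valid")
    # conv9_2: 3x3, stride 1, 'valid': 3 becomes 1
    sizes["conv9_2"] = conv_out(sizes["conv8_2"], 3, 1, "valid")
    return sizes

print(extra_layer_sizes())
# {'fc7': 19, 'conv6_2': 10, 'conv7_2': 5, 'conv8_2': 3, 'conv9_2': 1}
```

The same rule explains step d: if the 2x2, stride-2 pooling uses padding='same', it maps 75 to ceil(75 / 2) = 38, whereas 'valid' padding would give 37.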

2. Getting predictions from the features

[Figure]
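As the figure above shows, we take six feature maps: the output of the third convolution of conv4 (conv4_3), the output of fc7, and the outputs of the second convolution of conv6, conv7, conv8 and conv9. To distinguish them from ordinary feature layers we call them effective feature layers, and the predictions are obtained from them.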

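For each effective feature layer we apply one convolution with num_priors x 4 output channels and one convolution with num_priors x num_classes output channels, and we also compute the prior boxes that belong to that layer. num_priors is the number of prior boxes at each grid point of that feature layer.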

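Here:
The num_priors x 4 convolution predicts, for every prior box at every grid point of the feature layer, how that prior box should be adjusted. (Why an adjustment? SSD's raw outputs are not boxes on their own: they must be combined with the prior boxes to give the predicted boxes, so what the network predicts is the change applied to each prior box.)

The num_priors x num_classes convolution predicts the class of each predicted box at every grid point of the feature layer.

The prior boxes of an effective feature layer are the several boxes preset at every grid point of that layer.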
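To make the "adjustment" idea concrete, here is a minimal sketch of how the priors at one grid point could be laid out and how a predicted offset turns a prior into a box. It assumes the convention of the original ssd_keras PriorBox layer (one min_size square, one sqrt(min_size * max_size) square, then each aspect ratio followed by its reciprocal) and the common SSD center-size decoding with the variances [0.1, 0.1, 0.2, 0.2] passed to every PriorBox in the SSD300 code. The repository's own layer may also normalise and clip the boxes, which this sketch skips; `priors_at_cell` and `decode` are hypothetical names.

```python
import math

VARIANCES = (0.1, 0.1, 0.2, 0.2)  # same values as the variances argument of each PriorBox

def priors_at_cell(img_size, layer_size, row, col, min_size, max_size, aspect_ratios):
    """Prior boxes (cx, cy, w, h) in pixels for one grid cell of a square feature map.

    Assumed ordering: small square, large square, then each aspect ratio
    followed by its reciprocal.
    """
    step = img_size / layer_size                 # pixels covered by one grid cell
    cx, cy = (col + 0.5) * step, (row + 0.5) * step
    boxes = [(cx, cy, min_size, min_size)]       # small square
    big = math.sqrt(min_size * max_size)
    boxes.append((cx, cy, big, big))             # large square
    for ar in aspect_ratios:
        r = math.sqrt(ar)
        boxes.append((cx, cy, min_size * r, min_size / r))  # wide box, w/h = ar
        boxes.append((cx, cy, min_size / r, min_size * r))  # tall box, w/h = 1/ar
    return boxes

def decode(prior, loc, variances=VARIANCES):
    """Apply one predicted offset loc = (dx, dy, dw, dh) to a prior (cx, cy, w, h)."""
    cx, cy, w, h = prior
    dx, dy, dw, dh = loc
    return (cx + dx * variances[0] * w,          # shift centre by a fraction of the width
            cy + dy * variances[1] * h,          # shift centre by a fraction of the height
            w * math.exp(dw * variances[2]),     # rescale width in log space
            h * math.exp(dh * variances[3]))     # rescale height in log space

priors = priors_at_cell(300, 38, row=0, col=0, min_size=30.0, max_size=60.0, aspect_ratios=[2])
print(len(priors))                        # 4, matching num_priors = 4 for conv4_3
print(decode(priors[0], (0, 0, 0, 0)))    # zero offsets leave the prior unchanged
```

With all offsets equal to zero, `decode` returns the prior itself, which is why the regression output is best read as a change relative to the prior box rather than as a box.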

The shapes of the predictions for all the feature layers are as follows:
[Figure]
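The sketch below recomputes those shapes from the settings used in the code. It derives num_priors from the aspect_ratios passed to each PriorBox, assuming a max_size is given and every ratio is also flipped, which reproduces the 4 and 6 used in the SSD300 code, and it sums the boxes over all six layers. The conv9_2 entry is an assumption, since that layer's head lies outside this excerpt; `LAYERS`, `num_priors_for` and `head_shapes` are hypothetical names.

```python
# (layer name, feature-map side, aspect_ratios given to its PriorBox)
LAYERS = [
    ("conv4_3_norm", 38, [2]),
    ("fc7", 19, [2, 3]),
    ("conv6_2", 10, [2, 3]),
    ("conv7_2", 5, [2, 3]),
    ("conv8_2", 3, [2]),
    ("conv9_2", 1, [2]),  # assumption: not shown in this excerpt
]

def num_priors_for(aspect_ratios, has_max_size=True):
    """Priors per grid point: a min_size square, an optional sqrt(min*max) square,
    and every aspect ratio together with its reciprocal."""
    return 1 + int(has_max_size) + 2 * len(aspect_ratios)

def head_shapes(num_classes=21):
    """Per-layer (name, loc shape, conf shape) and the total number of priors."""
    rows, total = [], 0
    for name, side, ratios in LAYERS:
        k = num_priors_for(ratios)
        rows.append((name, (side, side, k * 4), (side, side, k * num_classes)))
        total += side * side * k
    return rows, total

rows, total = head_shapes()
for name, loc_shape, conf_shape in rows:
    print(f"{name:14s} loc {loc_shape}  conf {conf_shape}")
print("total priors:", total)  # 8732
```

Under these assumptions the six layers give 38x38x4 + 19x19x6 + 10x10x6 + 5x5x6 + 3x3x4 + 1x1x4 = 8732 prior boxes, the count reported for SSD300 in the original SSD paper.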
The implementation:

def SSD300(input_shape, num_classes=21):
    # 300,300,3
    input_tensor = Input(shape=input_shape)
    img_size = (input_shape[1], input_shape[0])
    # SSD structure: net is a dict of named layers
    net = VGG16(input_tensor)
    #-----------------------process the extracted backbone features---------------------------#
    # Process conv4_3: 38,38,512
    # L2-normalise conv4_3 with a learnable scale initialised to 20
    net['conv4_3_norm'] = Normalize(20, name='conv4_3_norm')(net['conv4_3'])
    num_priors = 4
    # Box regression branch
    # num_priors is the number of prior boxes per grid point; 4 holds the x, y, h, w adjustments
    net['conv4_3_norm_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same', name='conv4_3_norm_mbox_loc')(net['conv4_3_norm'])
    net['conv4_3_norm_mbox_loc_flat'] = Flatten(name='conv4_3_norm_mbox_loc_flat')(net['conv4_3_norm_mbox_loc'])
    # num_priors is the number of prior boxes per grid point; num_classes is the number of classes
    net['conv4_3_norm_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same', name='conv4_3_norm_mbox_conf')(net['conv4_3_norm'])
    net['conv4_3_norm_mbox_conf_flat'] = Flatten(name='conv4_3_norm_mbox_conf_flat')(net['conv4_3_norm_mbox_conf'])
    priorbox = PriorBox(img_size, 30.0, max_size=60.0, aspect_ratios=[2],
                        variances=[0.1, 0.1, 0.2, 0.2],
                        name='conv4_3_norm_mbox_priorbox')
    net['conv4_3_norm_mbox_priorbox'] = priorbox(net['conv4_3_norm'])

    # Process fc7: 19,19,1024
    num_priors = 6
    # Box regression branch
    # num_priors is the number of prior boxes per grid point; 4 holds the x, y, h, w adjustments
    net['fc7_mbox_loc'] = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same', name='fc7_mbox_loc')(net['fc7'])
    net['fc7_mbox_loc_flat'] = Flatten(name='fc7_mbox_loc_flat')(net['fc7_mbox_loc'])
    # num_priors is the number of prior boxes per grid point; num_classes is the number of classes
    net['fc7_mbox_conf'] = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same', name='fc7_mbox_conf')(net['fc7'])
    net['fc7_mbox_conf_flat'] = Flatten(name='fc7_mbox_conf_flat')(net['fc7_mbox_conf'])

    priorbox = PriorBox(img_size, 60.0, max_size=111.0, aspect_ratios=[2, 3],
                        variances=[0.1, 0.1, 0.2, 0.2],
                        name='fc7_mbox_priorbox')
    net['fc7_mbox_priorbox'] = priorbox(net['fc7'])

    # Process conv6_2: 10,10,512
    num_priors = 6
    # Box regression branch
    # num_priors is the number of prior boxes per grid point; 4 holds the x, y, h, w adjustments
    x = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same', name='conv6_2_mbox_loc')(net['conv6_2'])
    net['conv6_2_mbox_loc'] = x
    net['conv6_2_mbox_loc_flat'] = Flatten(name='conv6_2_mbox_loc_flat')(net['conv6_2_mbox_loc'])
    # num_priors is the number of prior boxes per grid point; num_classes is the number of classes
    x = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same', name='conv6_2_mbox_conf')(net['conv6_2'])
    net['conv6_2_mbox_conf'] = x
    net['conv6_2_mbox_conf_flat'] = Flatten(name='conv6_2_mbox_conf_flat')(net['conv6_2_mbox_conf'])

    priorbox = PriorBox(img_size, 111.0, max_size=162.0, aspect_ratios=[2, 3],
                        variances=[0.1, 0.1, 0.2, 0.2],
                        name='conv6_2_mbox_priorbox')
    net['conv6_2_mbox_priorbox'] = priorbox(net['conv6_2'])

    # Process conv7_2: 5,5,256
    num_priors = 6
    # Box regression branch
    # num_priors is the number of prior boxes per grid point; 4 holds the x, y, h, w adjustments
    x = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same', name='conv7_2_mbox_loc')(net['conv7_2'])
    net['conv7_2_mbox_loc'] = x
    net['conv7_2_mbox_loc_flat'] = Flatten(name='conv7_2_mbox_loc_flat')(net['conv7_2_mbox_loc'])
    # num_priors is the number of prior boxes per grid point; num_classes is the number of classes
    x = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same', name='conv7_2_mbox_conf')(net['conv7_2'])
    net['conv7_2_mbox_conf'] = x
    net['conv7_2_mbox_conf_flat'] = Flatten(name='conv7_2_mbox_conf_flat')(net['conv7_2_mbox_conf'])

    priorbox = PriorBox(img_size, 162.0, max_size=213.0, aspect_ratios=[2, 3],
                        variances=[0.1, 0.1, 0.2, 0.2],
                        name='conv7_2_mbox_priorbox')
    net['conv7_2_mbox_priorbox'] = priorbox(net['conv7_2'])

# Process conv8_2
num_priors = 4
# Box regression head
# num_priors is the number of prior boxes per grid point; 4 is the x, y, w, h adjustment
x = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same', name='conv8_2_mbox_loc')(net['conv8_2'])
net['conv8_2_mbox_loc'] = x
net['conv8_2_mbox_loc_flat'] = Flatten(name='conv8_2_mbox_loc_flat')(net['conv8_2_mbox_loc'])
# num_priors is the number of prior boxes per grid point; num_classes is the number of classes
x = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same', name='conv8_2_mbox_conf')(net['conv8_2'])
net['conv8_2_mbox_conf'] = x
net['conv8_2_mbox_conf_flat'] = Flatten(name='conv8_2_mbox_conf_flat')(net['conv8_2_mbox_conf'])

priorbox = PriorBox(img_size, 213.0, max_size=264.0, aspect_ratios=[2],
                    variances=[0.1, 0.1, 0.2, 0.2],
                    name='conv8_2_mbox_priorbox')
net['conv8_2_mbox_priorbox'] = priorbox(net['conv8_2'])

# Process conv9_2
num_priors = 4
# Box regression head
# num_priors is the number of prior boxes per grid point; 4 is the x, y, w, h adjustment
x = Conv2D(num_priors * 4, kernel_size=(3,3), padding='same', name='conv9_2_mbox_loc')(net['conv9_2'])
net['conv9_2_mbox_loc'] = x
net['conv9_2_mbox_loc_flat'] = Flatten(name='conv9_2_mbox_loc_flat')(net['conv9_2_mbox_loc'])
# num_priors is the number of prior boxes per grid point; num_classes is the number of classes
x = Conv2D(num_priors * num_classes, kernel_size=(3,3), padding='same', name='conv9_2_mbox_conf')(net['conv9_2'])
net['conv9_2_mbox_conf'] = x
net['conv9_2_mbox_conf_flat'] = Flatten(name='conv9_2_mbox_conf_flat')(net['conv9_2_mbox_conf'])

priorbox = PriorBox(img_size, 264.0, max_size=315.0, aspect_ratios=[2],
                    variances=[0.1, 0.1, 0.2, 0.2],
                    name='conv9_2_mbox_priorbox')
net['conv9_2_mbox_priorbox'] = priorbox(net['conv9_2'])

# Concatenate the outputs of all feature layers
net['mbox_loc'] = concatenate([net['conv4_3_norm_mbox_loc_flat'],
                               net['fc7_mbox_loc_flat'],
                               net['conv6_2_mbox_loc_flat'],
                               net['conv7_2_mbox_loc_flat'],
                               net['conv8_2_mbox_loc_flat'],
                               net['conv9_2_mbox_loc_flat']],
                              axis=1, name='mbox_loc')
net['mbox_conf'] = concatenate([net['conv4_3_norm_mbox_conf_flat'],
                                net['fc7_mbox_conf_flat'],
                                net['conv6_2_mbox_conf_flat'],
                                net['conv7_2_mbox_conf_flat'],
                                net['conv8_2_mbox_conf_flat'],
                                net['conv9_2_mbox_conf_flat']],
                               axis=1, name='mbox_conf')
net['mbox_priorbox'] = concatenate([net['conv4_3_norm_mbox_priorbox'],
                                    net['fc7_mbox_priorbox'],
                                    net['conv6_2_mbox_priorbox'],
                                    net['conv7_2_mbox_priorbox'],
                                    net['conv8_2_mbox_priorbox'],
                                    net['conv9_2_mbox_priorbox']],
                                   axis=1, name='mbox_priorbox')

# Total number of prior boxes (old Keras exposes _keras_shape; otherwise ask the backend)
if hasattr(net['mbox_loc'], '_keras_shape'):
    num_boxes = net['mbox_loc']._keras_shape[-1] // 4
else:
    num_boxes = K.int_shape(net['mbox_loc'])[-1] // 4
# (8732, 4)
net['mbox_loc'] = Reshape((num_boxes, 4), name='mbox_loc_final')(net['mbox_loc'])
# (8732, 21)
net['mbox_conf'] = Reshape((num_boxes, num_classes), name='mbox_conf_logits')(net['mbox_conf'])
net['mbox_conf'] = Activation('softmax', name='mbox_conf_final')(net['mbox_conf'])

# Last axis: 4 box offsets + num_classes scores + 4 prior box coords + 4 variances
net['predictions'] = concatenate([net['mbox_loc'],
                                  net['mbox_conf'],
                                  net['mbox_priorbox']],
                                 axis=2, name='predictions')
print(net['predictions'])
model = Model(net['input'], net['predictions'])
return model
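As a quick check on the `(8732, 4)` comment above, you can get the total number of prior boxes by summing grid cells × priors per cell over the six source layers. The grid sizes (38, 19, 10, 5, 3, 1) follow the backbone walkthrough. The text gives 4 priors per cell for conv4_3, conv8_2 and conv9_2. The value of 6 for fc7, conv6_2 and conv7_2 is assumed from the standard SSD300 configuration, and `count_priors` is a hypothetical helper.

```python
# Feature-map side length and priors per grid cell for each SSD300 source layer.
# conv4_3, conv8_2 and conv9_2 use 4 priors (as in the text); 6 for
# fc7, conv6_2 and conv7_2 is assumed from the standard SSD300 configuration.
LAYERS = [
    ("conv4_3", 38, 4),
    ("fc7", 19, 6),
    ("conv6_2", 10, 6),
    ("conv7_2", 5, 6),
    ("conv8_2", 3, 4),
    ("conv9_2", 1, 4),
]


def count_priors(layers):
    """Hypothetical helper: per-layer prior counts and their total."""
    per_layer = {name: size * size * k for name, size, k in layers}
    return per_layer, sum(per_layer.values())


if __name__ == "__main__":
    per_layer, total = count_priors(LAYERS)
    for name, n in per_layer.items():
        print(f"{name:8s} {n:5d}")
    print("total", total)  # 8732
    num_classes = 21  # e.g. 20 VOC classes + background
    # Last axis of 'predictions': 4 offsets + class scores + 4 prior coords + 4 variances
    print("predictions shape:", (total, 4 + num_classes + 8))
```

The conv4_3 layer alone contributes 5776 of the 8732 boxes. That is why the small-object branch dominates the prediction tensor.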

3. Decoding the prediction results

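By processing each effective feature layer we obtain three things:

1. The num_priors x 4 convolution predicts how each prior box at each grid point of that feature layer should be adjusted.

2. The num_priors x num_classes convolution predicts the class of each predicted box at each grid point of that feature layer.

3. The prior boxes of each effective feature layer, i.e. the set of boxes preset at each grid point of that layer.

Combining the output of the num_priors x 4 convolution with each effective feature layer's prior boxes gives the actual box positions.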
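Why does flattening each head's output and reshaping the concatenation to (num_boxes, 4) line up with the prior boxes? In a channels-last tensor, the num_priors * 4 channels of one grid cell are contiguous. Flatten therefore emits the values cell by cell, prior by prior, four numbers each. The NumPy sketch below mimics this on one hypothetical 3x3 feature map with 4 priors per cell. It assumes NumPy's row-major `reshape` behaves like Keras' `Flatten`/`Reshape` on a single channels-last sample.

```python
import numpy as np

H, W, num_priors = 3, 3, 4  # e.g. a conv8_2-sized map: 3x3 cells, 4 priors per cell

# Fake loc-head output for one image, shape (H, W, num_priors * 4), channels last.
# Each value encodes its origin as 1000*row + 100*col + 10*prior + coord.
rows, cols, priors, coords = np.meshgrid(np.arange(H), np.arange(W),
                                         np.arange(num_priors), np.arange(4),
                                         indexing="ij")
loc = (1000 * rows + 100 * cols + 10 * priors + coords).reshape(H, W, num_priors * 4)

flat = loc.reshape(-1)       # what Flatten produces for one sample
boxes = flat.reshape(-1, 4)  # what Reshape((num_boxes, 4)) produces

print(boxes.shape)  # (36, 4)
print(boxes[5])     # cell (0, 1), prior 1 -> [110 111 112 113]
```

Row 5 is cell (0, 1), prior 1, so every row holds the four offsets of exactly one prior. The prior boxes are generated in the same cell-then-prior order, which lets the two be concatenated side by side.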

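The prior boxes of each effective feature layer work as shown in the figure. Each effective feature layer divides the whole image into a grid matching its height and width. For example, the conv4_3 feature layer divides the image into a 38x38 grid, and 4 prior boxes are centered on each grid cell. So conv4_3 contributes 38x38x4 = 5776 prior boxes in total.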
(Figure: prior boxes placed on a feature-layer grid)
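To make the grid-and-prior idea concrete, here is a minimal NumPy sketch that places prior boxes at every cell center of an n×n feature map. It returns them as normalized (xmin, ymin, xmax, ymax). It assumes the standard SSD rule for a layer with aspect_ratios=[2], which gives 4 priors per cell:

- a min_size square;
- a sqrt(min_size·max_size) square;
- two boxes with aspect ratios 2 and 1/2.

`make_priors` is a hypothetical helper, not the repository's `PriorBox` layer, and its box order within a cell may differ. The 213/264 sizes come from the conv8_2 `PriorBox` call in the model code.

```python
import numpy as np


def make_priors(feature_size, img_size, min_size, max_size):
    """Hypothetical helper: prior boxes for one square feature map.

    Uses the 4-prior SSD layout (aspect_ratios=[2]) and returns an array of
    shape (feature_size * feature_size * 4, 4) with normalized corners.
    """
    step = img_size / feature_size
    # Box (width, height) pairs in pixels: min square, sqrt(min*max) square,
    # then aspect ratios 2 and 1/2 built from min_size.
    big = np.sqrt(min_size * max_size)
    r = np.sqrt(2.0)
    wh = np.array([[min_size, min_size],
                   [big, big],
                   [min_size * r, min_size / r],
                   [min_size / r, min_size * r]])
    # Pixel centers of every grid cell, row by row.
    centers = (np.arange(feature_size) + 0.5) * step
    cx, cy = np.meshgrid(centers, centers)
    cx = cx.reshape(-1, 1)
    cy = cy.reshape(-1, 1)
    # Broadcast so every cell gets all 4 box shapes: each array is (cells, 4).
    xmin = cx - 0.5 * wh[:, 0]
    ymin = cy - 0.5 * wh[:, 1]
    xmax = cx + 0.5 * wh[:, 0]
    ymax = cy + 0.5 * wh[:, 1]
    boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)
    # Normalize to [0, 1] and clip boxes that spill past the image border.
    return np.clip(boxes / img_size, 0.0, 1.0)


priors = make_priors(feature_size=3, img_size=300, min_size=213.0, max_size=264.0)
print(priors.shape)  # (36, 4)
```

On a 38x38 map the same function yields 5776 boxes, matching the conv4_3 count above. Large priors on small maps like conv8_2 overlap heavily, which is why they are clipped to the image.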
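Prior boxes carry some information about box position and size, but there are only finitely many of them and they cannot represent arbitrary boxes, so they have to be adjusted. SSD adjusts them with the output of the num_priors x 4 convolution.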

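In num_priors x 4, num_priors is the number of prior boxes at this grid point, and 4 is the four adjustment values x_offset, y_offset, w and h.

x_offset and y_offset describe how far the box center is shifted from the prior box center along the x and y axes.
w and h describe how the box's width and height change relative to the prior box.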

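SSD decoding adds x_offset and y_offset to each grid cell's center, which is also the prior box center. The offsets are first scaled by the prior box's width/height and by the variances, and the result is the center of the predicted box. The predicted box's width and height are then computed from the prior box's size together with w and h (through an exponential, as in the code below). This gives the full position of the predicted box.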
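Written out as equations, matching the `decode_boxes` code below: let a prior box have center (p_x, p_y) and size (p_w, p_h). Let the network output for it be (t_x, t_y, t_w, t_h), and let the variances be (v_0, v_1, v_2, v_3) = (0.1, 0.1, 0.2, 0.2). These symbol names are introduced here for illustration only.

```latex
\begin{aligned}
b_x &= p_x + t_x \, v_0 \, p_w, &\quad b_y &= p_y + t_y \, v_1 \, p_h,\\
b_w &= p_w \exp(t_w \, v_2), &\quad b_h &= p_h \exp(t_h \, v_3),\\
x_{\min} &= b_x - \tfrac{1}{2} b_w, &\quad y_{\min} &= b_y - \tfrac{1}{2} b_h,\\
x_{\max} &= b_x + \tfrac{1}{2} b_w, &\quad y_{\max} &= b_y + \tfrac{1}{2} b_h.
\end{aligned}
```

Each corner is then clipped to [0, 1]. With t = (0, 0, 0, 0) the decoded box is exactly the prior box, since exp(0) = 1.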

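After the final predictions are obtained, they still go through score ranking and non-maximum suppression (NMS); this part is common to almost all object detectors.
1. For each class, take the boxes and scores whose score exceeds confidence_threshold.
2. Apply non-maximum suppression using the box positions and scores.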
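The code below performs NMS through a TensorFlow op (`self.sess.run(self.nms, ...)`), whose construction is not shown in this excerpt. To see what that step does, here is a plain NumPy sketch of greedy single-class NMS, preceded by the confidence filter from step 1. The names `nms_numpy` and `iou_one_to_many` are hypothetical, and the IoU threshold of 0.45 is an illustrative assumption, not a value taken from the repository.

```python
import numpy as np


def iou_one_to_many(box, boxes):
    """IoU between one box and an (N, 4) array; corners are (xmin, ymin, xmax, ymax)."""
    ix1 = np.maximum(box[0], boxes[:, 0])
    iy1 = np.maximum(box[1], boxes[:, 1])
    ix2 = np.minimum(box[2], boxes[:, 2])
    iy2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter + 1e-12)


def nms_numpy(boxes, scores, conf_threshold=0.5, iou_threshold=0.45, top_k=200):
    """Greedy NMS for a single class; returns indices into the original arrays."""
    # Step 1: drop low-confidence boxes.
    candidates = np.where(scores > conf_threshold)[0]
    # Visit the remaining boxes from highest to lowest score.
    order = candidates[np.argsort(scores[candidates])[::-1]]
    keep = []
    while order.size > 0 and len(keep) < top_k:
        best = order[0]
        keep.append(best)
        # Step 2: suppress the boxes that overlap the kept one too much.
        overlaps = iou_one_to_many(boxes[best], boxes[order[1:]])
        order = order[1:][overlaps <= iou_threshold]
    return np.array(keep, dtype=int)


boxes = np.array([[0.10, 0.10, 0.50, 0.50],
                  [0.12, 0.12, 0.52, 0.52],   # near-duplicate of box 0
                  [0.60, 0.60, 0.90, 0.90],
                  [0.00, 0.00, 0.20, 0.20]])
scores = np.array([0.9, 0.8, 0.7, 0.3])
print(nms_numpy(boxes, scores))  # [0 2]
```

In this example, box 3 fails the confidence filter. Box 1 overlaps box 0 with an IoU of about 0.82 and is suppressed.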

The implementation is as follows:

import numpy as np

def decode_boxes(self, mbox_loc, mbox_priorbox, variances):
    # Width and height of each prior box
    prior_width = mbox_priorbox[:, 2] - mbox_priorbox[:, 0]
    prior_height = mbox_priorbox[:, 3] - mbox_priorbox[:, 1]
    # Center of each prior box
    prior_center_x = 0.5 * (mbox_priorbox[:, 2] + mbox_priorbox[:, 0])
    prior_center_y = 0.5 * (mbox_priorbox[:, 3] + mbox_priorbox[:, 1])

    # Shift the prior centers by the predicted x/y offsets (scaled by prior size and variance)
    decode_bbox_center_x = mbox_loc[:, 0] * prior_width * variances[:, 0]
    decode_bbox_center_x += prior_center_x
    decode_bbox_center_y = mbox_loc[:, 1] * prior_height * variances[:, 1]
    decode_bbox_center_y += prior_center_y

    # Width and height of the predicted boxes
    decode_bbox_width = np.exp(mbox_loc[:, 2] * variances[:, 2])
    decode_bbox_width *= prior_width
    decode_bbox_height = np.exp(mbox_loc[:, 3] * variances[:, 3])
    decode_bbox_height *= prior_height

    # Top-left and bottom-right corners of the predicted boxes
    decode_bbox_xmin = decode_bbox_center_x - 0.5 * decode_bbox_width
    decode_bbox_ymin = decode_bbox_center_y - 0.5 * decode_bbox_height
    decode_bbox_xmax = decode_bbox_center_x + 0.5 * decode_bbox_width
    decode_bbox_ymax = decode_bbox_center_y + 0.5 * decode_bbox_height

    # Stack the corners into (num_boxes, 4)
    decode_bbox = np.concatenate((decode_bbox_xmin[:, None],
                                  decode_bbox_ymin[:, None],
                                  decode_bbox_xmax[:, None],
                                  decode_bbox_ymax[:, None]), axis=-1)
    # Clip to the [0, 1] range
    decode_bbox = np.minimum(np.maximum(decode_bbox, 0.0), 1.0)
    return decode_bbox

def detection_out(self, predictions, background_label_id=0, keep_top_k=200,
                  confidence_threshold=0.5):
    # Raw network output, shape (batch, num_boxes, 4 + num_classes + 8)
    # Predicted box offsets
    mbox_loc = predictions[:, :, :4]
    # Variances: 0.1, 0.1, 0.2, 0.2
    variances = predictions[:, :, -4:]
    # Prior boxes
    mbox_priorbox = predictions[:, :, -8:-4]
    # Class confidences
    mbox_conf = predictions[:, :, 4:-8]
    results = []
    # Process each image in the batch
    for i in range(len(mbox_loc)):
        results.append([])
        decode_bbox = self.decode_boxes(mbox_loc[i], mbox_priorbox[i], variances[i])

        for c in range(self.num_classes):
            if c == background_label_id:
                continue
            c_confs = mbox_conf[i, :, c]
            c_confs_m = c_confs > confidence_threshold
            if len(c_confs[c_confs_m]) > 0:
                # Keep the boxes whose score is above confidence_threshold
                boxes_to_process = decode_bbox[c_confs_m]
                confs_to_process = c_confs[c_confs_m]
                # IoU-based non-maximum suppression
                feed_dict = {self.boxes: boxes_to_process,
                             self.scores: confs_to_process}
                idx = self.sess.run(self.nms, feed_dict=feed_dict)
                # Keep the boxes that survive NMS
                good_boxes = boxes_to_process[idx]
                confs = confs_to_process[idx][:, None]
                # Stack label, confidence and box coordinates
                labels = c * np.ones((len(idx), 1))
                c_pred = np.concatenate((labels, confs, good_boxes),
                                        axis=1)
                # Add to this image's results
                results[-1].extend(c_pred)
        if len(results[-1]) > 0:
            # Sort by confidence, highest first
            results[-1] = np.array(results[-1])
            argsort = np.argsort(results[-1][:, 1])[::-1]
        results<span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">=</span> results<span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">[</span>argsort<span class="token punctuation">]</span>
        <span class="token comment"># 选出置信度最大的keep_top_k个</span>
        results<span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">=</span> results<span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">[</span><span class="token punctuation">:</span>keep_top_k<span class="token punctuation">]</span>
<span class="token keyword">return</span> results
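The NMS step above calls `self.sess.run(self.nms, ...)`; the `self.nms` op is built elsewhere and is not shown in this section. To make the idea concrete, here is a minimal pure-NumPy sketch of greedy IoU-based non-maximum suppression that returns the same kind of index array `idx`. It assumes boxes in `[xmin, ymin, xmax, ymax]` format; the name `nms_numpy` and the default values of `iou_threshold` and `top_k` are illustrative choices, not the repository's API or settings.

```python
import numpy as np


def nms_numpy(boxes, scores, iou_threshold=0.45, top_k=300):
    """Greedy non-maximum suppression (hypothetical helper).

    boxes:  (N, 4) array of [xmin, ymin, xmax, ymax]
    scores: (N,) array of confidences
    Returns the indices of the kept boxes, highest score first.
    """
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]  # candidate indices, best first
    keep = []
    while order.size > 0 and len(keep) < top_k:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        # Intersection of the kept box with every remaining candidate
        ix1 = np.maximum(boxes[best, 0], boxes[rest, 0])
        iy1 = np.maximum(boxes[best, 1], boxes[rest, 1])
        ix2 = np.minimum(boxes[best, 2], boxes[rest, 2])
        iy2 = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.maximum(ix2 - ix1, 0) * np.maximum(iy2 - iy1, 0)
        iou = inter / (areas[best] + areas[rest] - inter)
        # Discard candidates that overlap the kept box too much
        order = rest[iou <= iou_threshold]
    return np.array(keep, dtype=np.int64)


boxes = np.array([[0.1, 0.1, 0.5, 0.5],
                  [0.12, 0.1, 0.5, 0.52],
                  [0.6, 0.6, 0.9, 0.9]])
scores = np.array([0.9, 0.8, 0.7])
idx = nms_numpy(boxes, scores)
```

The second box almost coincides with the first and has a lower score, so it is suppressed; the third box does not overlap either and is kept.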

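4. Drawing on the original image

Step 3 gives us the positions of the predicted boxes on the original image, and these boxes have already been filtered. They can be drawn directly onto the image to obtain the final result.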
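As an illustration, assume each detection row has the layout built by the code above, `[label, confidence, xmin, ymin, xmax, ymax]`, with coordinates normalized to [0, 1] (an assumption; the repository's own drawing code may use another convention and a plotting library). Then the boxes only need to be scaled by the image width and height before drawing. This sketch draws one-pixel outlines directly into a NumPy image array so it needs nothing beyond NumPy; `draw_detections` is a hypothetical name.

```python
import numpy as np


def draw_detections(image, detections, color=(255, 0, 0)):
    """Draw box outlines for rows of [label, conf, xmin, ymin, xmax, ymax].

    image: (H, W, 3) uint8 array, modified in place and returned.
    Box coordinates are assumed to be normalized to [0, 1].
    """
    h, w = image.shape[:2]
    for label, conf, xmin, ymin, xmax, ymax in detections:
        # Scale normalized coordinates to pixels and clip to the image
        left = int(np.clip(round(xmin * w), 0, w - 1))
        top = int(np.clip(round(ymin * h), 0, h - 1))
        right = int(np.clip(round(xmax * w), 0, w - 1))
        bottom = int(np.clip(round(ymax * h), 0, h - 1))
        image[top, left:right + 1] = color      # top edge
        image[bottom, left:right + 1] = color   # bottom edge
        image[top:bottom + 1, left] = color     # left edge
        image[top:bottom + 1, right] = color    # right edge
    return image


img = np.zeros((100, 200, 3), dtype=np.uint8)
dets = np.array([[1, 0.9, 0.1, 0.2, 0.5, 0.6]])
draw_detections(img, dets)
```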

II. Training

1. Processing the ground-truth boxes

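From the prediction part we know that, in the output of each feature layer, the num_priors x 4 convolution predicts how every prior box at every grid point of that feature layer should be adjusted.

In other words, the raw output of the SSD network is not the actual position of the predicted boxes in the image; it has to be decoded to obtain the real positions.

During training we need to compute the loss, and this loss is defined relative to the raw SSD predictions. We feed the image into the current SSD network to obtain its predictions; at the same time we have to encode the ground-truth boxes, i.e. convert their position information into the same format as the SSD predictions.

That is, for every ground-truth box in every training image we must find the corresponding prior box(es) and work out what the prediction would have to be in order to recover exactly that ground-truth box.

Going from the predictions to a box is called decoding; going from a ground-truth box to the target predictions is called encoding.

Encoding is therefore simply the decoding process run in reverse.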
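A minimal NumPy sketch of this encode/decode pair for a single box and a single prior makes the inversion explicit. It assumes the `[xmin, ymin, xmax, ymax]` layout and the variances 0.1 (center) and 0.2 (width/height) that `encode_box` divides by; the helper names `to_center`, `encode` and `decode` are illustrative only.

```python
import numpy as np

VARIANCES = np.array([0.1, 0.1, 0.2, 0.2])  # assumed SSD variances


def to_center(box):
    """[xmin, ymin, xmax, ymax] -> (center, width/height)."""
    return 0.5 * (box[:2] + box[2:]), box[2:] - box[:2]


def encode(gt, prior, var=VARIANCES):
    """Ground-truth box -> offsets the network should predict."""
    g_c, g_wh = to_center(gt)
    p_c, p_wh = to_center(prior)
    t_xy = (g_c - p_c) / p_wh / var[:2]
    t_wh = np.log(g_wh / p_wh) / var[2:]
    return np.concatenate([t_xy, t_wh])


def decode(t, prior, var=VARIANCES):
    """Inverse of encode: predicted offsets -> box."""
    p_c, p_wh = to_center(prior)
    c = t[:2] * var[:2] * p_wh + p_c
    wh = np.exp(t[2:] * var[2:]) * p_wh
    return np.concatenate([c - 0.5 * wh, c + 0.5 * wh])


prior = np.array([0.2, 0.2, 0.4, 0.4])
gt = np.array([0.22, 0.18, 0.46, 0.42])
offsets = encode(gt, prior)
recovered = decode(offsets, prior)  # equals gt up to floating-point error
```

A ground-truth box identical to its prior encodes to all-zero offsets, which is why the network only has to learn small corrections.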

The implementation is as follows:

def encode_box(self, box, return_iou=True):
    iou = self.iou(box)
    encoded_box = np.zeros((self.num_priors, 4 + return_iou))
    # Priors whose overlap with this ground-truth box exceeds the threshold
    assign_mask = iou > self.overlap_threshold
    # If none exceeds it, fall back to the single best-overlapping prior
    if not assign_mask.any():
        assign_mask[iou.argmax()] = True
    if return_iou:
        encoded_box[:, -1][assign_mask] = iou[assign_mask]

    # The matched priors
    assigned_priors = self.priors[assign_mask]
    # Reverse the decoding: convert the ground-truth box into the SSD prediction format

    # Center and width/height of the ground-truth box
    box_center = 0.5 * (box[:2] + box[2:])
    box_wh = box[2:] - box[:2]
    # Center and width/height of the matched priors
    assigned_priors_center = 0.5 * (assigned_priors[:, :2] +
                                    assigned_priors[:, 2:4])
    assigned_priors_wh = (assigned_priors[:, 2:4] -
                          assigned_priors[:, :2])

    # Target offsets the SSD network should predict
    encoded_box[:, :2][assign_mask] = box_center - assigned_priors_center
    encoded_box[:, :2][assign_mask] /= assigned_priors_wh
    # Divide by the center variances (0.1)
    encoded_box[:, :2][assign_mask] /= assigned_priors[:, -4:-2]

    encoded_box[:, 2:4][assign_mask] = np.log(box_wh / assigned_priors_wh)
    # Divide by the width/height variances (0.2)
    encoded_box[:, 2:4][assign_mask] /= assigned_priors[:, -2:]
    return encoded_box.ravel()
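`encode_box` starts with `self.iou(box)`, which is not shown in this section. Assuming `self.priors` stores `[xmin, ymin, xmax, ymax]` in its first four columns followed by the four variances (as the slicing in `encode_box` suggests), the IoU of one ground-truth box against all priors can be computed in a vectorized way as below. `iou_one_to_many` is a hypothetical stand-in, not the repository's method.

```python
import numpy as np


def iou_one_to_many(box, priors):
    """IoU between one box and each row of priors (first 4 columns used)."""
    inter_min = np.maximum(priors[:, :2], box[:2])
    inter_max = np.minimum(priors[:, 2:4], box[2:4])
    # Width/height of the intersection, 0 when the boxes do not overlap
    inter_wh = np.maximum(inter_max - inter_min, 0.0)
    inter = inter_wh[:, 0] * inter_wh[:, 1]
    area_box = (box[2] - box[0]) * (box[3] - box[1])
    area_priors = (priors[:, 2] - priors[:, 0]) * (priors[:, 3] - priors[:, 1])
    return inter / (area_box + area_priors - inter)


# Three priors: 4 coordinates + 4 variances each
priors = np.array([[0.0, 0.0, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2],
                   [0.25, 0.0, 0.75, 0.5, 0.1, 0.1, 0.2, 0.2],
                   [0.6, 0.6, 1.0, 1.0, 0.1, 0.1, 0.2, 0.2]])
box = np.array([0.0, 0.0, 0.5, 0.5])
ious = iou_one_to_many(box, priors)  # identical, half-overlapping, disjoint
```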

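With the code above we obtain, for one ground-truth box, all prior boxes with a sufficiently large IoU (above overlap_threshold, or the single best-overlapping prior if none exceeds it), together with the predictions each of those priors would need to produce.

During training, however, each prior box may be responsible for only one ground-truth box. When a prior overlaps several ground-truth boxes, it is assigned to the one with which it has the highest IoU.

So one more filtering step is needed: from the encodings computed above, we keep, for each prior box, only the one belonging to the ground-truth box with the largest IoU.

assign_boxes thus gives us the predictions the network should produce for the input image.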

The implementation is as follows:

def assign_boxes(self, boxes):
    # Per-prior targets: 4 offsets + num_classes class scores (column 4 is
    # background) + 8 extra columns, of which column -8 flags positive priors
    assignment = np.zeros((self.num_priors, 4 + self.num_classes + 8))
    # By default every prior is background
    assignment[:, 4] = 1.0
    if len(boxes) == 0:
        return assignment
    # Encode every ground-truth box against all priors (IoU in the last column)
    encoded_boxes = np.apply_along_axis(self.encode_box, 1, boxes[:, :4])
    # Shape (num_gt, num_priors, 5): 4 encoded offsets + IoU
    encoded_boxes = encoded_boxes.reshape(-1, self.num_priors, 5)

    # For each prior, the highest IoU over all ground-truth boxes and the
    # index of the ground-truth box that achieves it
    best_iou = encoded_boxes[:, :, -1].max(axis=0)
    best_iou_idx = encoded_boxes[:, :, -1].argmax(axis=0)
    best_iou_mask = best_iou > 0
    best_iou_idx = best_iou_idx[best_iou_mask]

    assign_num = len(best_iou_idx)
    # For each matched prior, keep the encoding w.r.t. its best ground-truth box
    encoded_boxes = encoded_boxes[:, best_iou_mask, :]
    assignment[:, :4][best_iou_mask] = encoded_boxes[best_iou_idx, np.arange(assign_num), :4]
    # Column 4 is the background probability: 0 for matched priors
    assignment[:, 4][best_iou_mask] = 0
    # One-hot class labels of the matched ground-truth boxes
    assignment[:, 5:-8][best_iou_mask] = boxes[best_iou_idx, 4:]
    # Mark matched priors as positives
    assignment[:, -8][best_iou_mask] = 1
    # assignment now holds the predictions the network should produce for this image
    return assignment
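The key step in assign_boxes is the max/argmax over axis 0 of the (num_gt, num_priors) IoU matrix: for every prior it picks the ground-truth box it overlaps most, and priors whose best IoU is 0 (never assigned in encode_box) stay background. A tiny NumPy illustration with made-up IoU values:

```python
import numpy as np

# Made-up IoUs, shape (num_gt=2, num_priors=4); 0 means "not assigned"
ious = np.array([[0.7, 0.6, 0.0, 0.0],
                 [0.0, 0.8, 0.55, 0.0]])

best_iou = ious.max(axis=0)      # best IoU for each prior
best_gt = ious.argmax(axis=0)    # index of the ground-truth box achieving it
positive = best_iou > 0          # priors that become positives

# Prior 1 overlaps both boxes and goes to box 1 (IoU 0.8 > 0.6);
# prior 3 matched nothing and remains background.
assigned_gt = best_gt[positive]
```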

2. Computing the loss from the processed ground-truth boxes and the predictions for the image

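The loss consists of three parts:
1. the regression (localization) loss on the predictions of all positive priors;
2. the cross-entropy classification loss on all positive priors;
3. the cross-entropy classification loss on a selected subset of negative priors.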
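For reference, the per-prior terms computed by `_l1_smooth_loss` and `_softmax_loss` correspond to the formulas below, where $t_i$ are the encoded targets from assign_boxes, $\hat t_i$ and $\hat p_i$ are the network outputs, $Pos$ is the set of positive priors and $Neg$ the selected negatives. The SSD paper combines the two parts as in the last line, with $N$ the number of positive priors and $\alpha$ the weight that appears as `alpha` in MultiboxLoss; the exact normalization used by this repository is not shown in this section.

$$\mathrm{smooth}_{L1}(x) = \begin{cases} 0.5\,x^2 & \text{if } |x| < 1 \\ |x| - 0.5 & \text{otherwise} \end{cases}$$

$$L_{loc} = \sum_{i \in Pos} \sum_{j=1}^{4} \mathrm{smooth}_{L1}\left(t_{ij} - \hat t_{ij}\right), \qquad L_{conf} = -\sum_{i \in Pos \cup Neg} \sum_{k} y_{ik} \log \hat p_{ik}$$

$$L = \frac{1}{N}\left(L_{conf} + \alpha\, L_{loc}\right)$$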

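Positive and negative samples are extremely imbalanced in SSD training: a ground-truth box may be matched by only 2 to 3 prior boxes, while thousands of priors match no ground-truth box at all, so the negatives would dominate the loss. We therefore use only part of the negatives; for SSD the common choice is three times as many negatives as positives. This factor of three (neg_pos_ratio in the code) can be changed to whatever value suits you.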
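The selection of negatives can be sketched for a single image in NumPy: keep every positive prior, then, among the remaining priors, take the `neg_pos_ratio * num_pos` ones with the highest non-background confidence (the "hard" negatives the network is most wrong about). The function name and array layout are assumptions for illustration. The TensorFlow code below works on a whole batch with one shared count and falls back to `negatives_for_hard` when no image has negatives, which this sketch does not do.

```python
import numpy as np


def select_hard_negatives(max_conf, is_pos, neg_pos_ratio=3.0):
    """Return indices of hard negatives for one image (hypothetical helper).

    max_conf: (num_priors,) highest non-background confidence per prior
    is_pos:   (num_priors,) 1.0 for positive priors, 0.0 otherwise
    """
    num_pos = int(is_pos.sum())
    # At most neg_pos_ratio * num_pos negatives, and no more than exist
    num_neg = int(min(neg_pos_ratio * num_pos, len(is_pos) - num_pos))
    # Exclude positives by giving them the lowest possible score
    scores = np.where(is_pos > 0, -np.inf, max_conf)
    # Indices of the num_neg most confident non-positive priors
    return np.argsort(scores)[::-1][:num_neg]


max_conf = np.array([0.9, 0.05, 0.6, 0.3, 0.8, 0.1, 0.2, 0.7])
is_pos = np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=float)
neg_idx = select_hard_negatives(max_conf, is_pos)  # 1 positive -> 3 negatives
```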

The implementation is as follows:

import tensorflow as tf


class MultiboxLoss(object):
    def __init__(self, num_classes, alpha=1.0, neg_pos_ratio=3.0,
                 background_label_id=0, negatives_for_hard=100.0):
        self.num_classes = num_classes
        self.alpha = alpha
        self.neg_pos_ratio = neg_pos_ratio
        if background_label_id != 0:
            raise Exception('Only 0 as background label id is supported')
        self.background_label_id = background_label_id
        self.negatives_for_hard = negatives_for_hard

    def _l1_smooth_loss(self, y_true, y_pred):
        # Smooth L1: 0.5 * x^2 if |x| < 1, else |x| - 0.5, summed over the 4 coordinates
        abs_loss = tf.abs(y_true - y_pred)
        sq_loss = 0.5 * (y_true - y_pred)**2
        l1_loss = tf.where(tf.less(abs_loss, 1.0), sq_loss, abs_loss - 0.5)
        return tf.reduce_sum(l1_loss, -1)

    def _softmax_loss(self, y_true, y_pred):
        # Cross-entropy; predictions are clipped to avoid log(0)
        y_pred = tf.maximum(tf.minimum(y_pred, 1 - 1e-15), 1e-15)
        softmax_loss = -tf.reduce_sum(y_true * tf.log(y_pred),
                                      axis=-1)
        return softmax_loss

    def compute_loss(self, y_true, y_pred):
        batch_size = tf.shape(y_true)[0]
        num_boxes = tf.to_float(tf.shape(y_true)[1])

        # Per-prior losses
        # Classification loss
        # batch_size,8732,21 -> batch_size,8732
        conf_loss = self._softmax_loss(y_true[:, :, 4:-8],
                                       y_pred[:, :, 4:-8])
        # Localization loss
        # batch_size,8732,4 -> batch_size,8732
        loc_loss = self._l1_smooth_loss(y_true[:, :, :4],
                                        y_pred[:, :, :4])

        # Losses of the positive priors
        # Number of positives in each image
        num_pos = tf.reduce_sum(y_true[:, :, -8], axis=-1)
        # Localization loss summed over the positives of each image
        pos_loc_loss = tf.reduce_sum(loc_loss * y_true[:, :, -8],
                                     axis=1)
        # Classification loss summed over the positives of each image
        pos_conf_loss = tf.reduce_sum(conf_loss * y_true[:, :, -8],
                                      axis=1)

        # Number of negatives per image: neg_pos_ratio times the positives,
        # capped by the number of non-positive priors
        num_neg = tf.minimum(self.neg_pos_ratio * num_pos,
                             num_boxes - num_pos)

        # Which images have at least one negative to use
        pos_num_neg_mask = tf.greater(num_neg, 0)
        # 1.0 if any image has num_neg > 0, otherwise 0.0
        has_min = tf.to_float(tf.reduce_any(pos_num_neg_mask))
        # If no image has negatives, fall back to negatives_for_hard
        num_neg = tf.concat(axis=0, values=[num_neg,
                            [(1 - has_min) * self.negatives_for_hard]])
        # Average number of negatives per image (over the non-zero entries)
        num_neg_batch = tf.reduce_mean(tf.boolean_mask(num_neg,
                                                      tf.greater(num_neg, 0)))
        num_neg_batch = tf.to_int32(num_neg_batch)

        # First non-background class column
        confs_start = 4 + self.background_label_id + 1
        # End (exclusive) of the class columns
        confs_end = confs_start + self.num_classes - 1

        # For every prior, the highest predicted non-background confidence
        max_confs = tf.reduce_max(y_pred[:, :, confs_start:confs_end],
                                  axis=2)

        # Hard negatives: non-positive priors with the top num_neg_batch confidences
        _, indices = tf.nn.top_k(max_confs * (1 - y_true[:, :, -8]),
                                 k=num_neg_batch)

        # Convert (image, prior) indices into indices of the flattened tensor
        batch_idx = tf.expand_dims(tf.range(0, batch_size), 1)
        batch_idx = tf.tile(batch_idx, (1, num_neg_batch))
        full_indices = (tf.reshape(batch_idx, [-1]) * tf.to_int32(num_boxes) +
                    tf<span class="token punctuation">.</span>reshape<span class="token punctuation">(</span>indices<span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
    

    neg_conf_loss <span class="token operator">=</span> tf<span class="token punctuation">.</span>gather<span class="token punctuation">(</span>tf<span class="token punctuation">.</span>reshape<span class="token punctuation">(</span>conf_loss<span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                              full_indices<span class="token punctuation">)</span>
    neg_conf_loss <span class="token operator">=</span> tf<span class="token punctuation">.</span>reshape<span class="token punctuation">(</span>neg_conf_loss<span class="token punctuation">,</span>
                               <span class="token punctuation">[</span>batch_size<span class="token punctuation">,</span> num_neg_batch<span class="token punctuation">]</span><span class="token punctuation">)</span>
    neg_conf_loss <span class="token operator">=</span> tf<span class="token punctuation">.</span>reduce_sum<span class="token punctuation">(</span>neg_conf_loss<span class="token punctuation">,</span> axis<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>

    <span class="token comment"># 求loss总和</span>
    total_loss <span class="token operator">=</span> K<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>pos_conf_loss <span class="token operator">+</span> neg_conf_loss<span class="token punctuation">)</span><span class="token operator">/</span>K<span class="token punctuation">.</span>cast<span class="token punctuation">(</span>batch_size<span class="token punctuation">,</span>K<span class="token punctuation">.</span>dtype<span class="token punctuation">(</span>pos_conf_loss<span class="token punctuation">)</span><span class="token punctuation">)</span>

    total_loss <span class="token operator">+=</span>  K<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>alpha <span class="token operator">*</span> pos_loc_loss<span class="token punctuation">)</span><span class="token operator">/</span>K<span class="token punctuation">.</span>cast<span class="token punctuation">(</span>batch_size<span class="token punctuation">,</span>K<span class="token punctuation">.</span>dtype<span class="token punctuation">(</span>pos_loc_loss<span class="token punctuation">)</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> total_loss
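To make the hard negative mining above easier to follow, here is a minimal NumPy sketch of the same steps on plain arrays. It is not the repo's code. The function names are hypothetical, and pos_mask stands in for y_true[:, :, -8]. It also assumes that at least one image in the batch has a positive negative-count, as the boolean_mask/reduce_mean step above does.

```python
import numpy as np


def num_neg_batch_from(num_neg):
    """Mean of the strictly positive per-image negative counts, truncated to int.

    Mirrors num_neg_batch above; assumes at least one entry of num_neg is > 0.
    """
    return int(np.mean(num_neg[num_neg > 0]))


def hard_negative_conf_loss(conf_loss, max_confs, pos_mask, num_neg_batch):
    """Per-image sum of the confidence loss of the num_neg_batch hardest negatives.

    conf_loss, max_confs and pos_mask have shape (batch_size, num_boxes);
    pos_mask plays the role of y_true[:, :, -8] (1.0 = positive box).
    """
    batch_size, num_boxes = conf_loss.shape
    # Zero the scores of positives, as the TF code does with (1 - y_true[:, :, -8])
    scores = max_confs * (1.0 - pos_mask)
    # Column indices of the k largest scores per row; a stable sort on the negated
    # scores breaks ties by lower index, like tf.nn.top_k
    indices = np.argsort(-scores, axis=1, kind="stable")[:, :num_neg_batch]
    # Flatten (row, column) pairs into 1-D indices: row * num_boxes + column
    batch_idx = np.repeat(np.arange(batch_size), num_neg_batch)
    full_indices = batch_idx * num_boxes + indices.reshape(-1)
    neg_conf_loss = conf_loss.reshape(-1)[full_indices]
    return neg_conf_loss.reshape(batch_size, num_neg_batch).sum(axis=1)
```

Flattening to 1-D and doing a single gather gives the same values as indexing with (row, column) pairs. The TF code uses the flat form so that one tf.gather call covers the whole batch.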

Training your own SSD model

The overall folder structure of the SSD project is as follows:
[Image]
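This article trains on data in VOC format.
Before training, put the label files in the Annotations folder under VOCdevkit/VOC2007.
[Image]
Before training, put the image files in the JPEGImages folder under VOCdevkit/VOC2007.
[Image]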
Before training, run voc2ssd.py to generate the corresponding txt files.
[Image]
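The job of voc2ssd.py is to turn the XML files into lists of image IDs. A minimal sketch of such a split follows. It is not the repo's script. The ratios, the seed, the ImageSets/Main output folder and the function names are assumptions based on the usual VOC layout.

```python
import random
from pathlib import Path


def split_ids(ids, trainval_ratio=1.0, train_ratio=1.0, seed=0):
    """Randomly split image IDs into trainval/train/val/test lists (each kept sorted).

    Hypothetical helper: ratios and seed are assumptions, not taken from voc2ssd.py.
    """
    ids = sorted(ids)
    rng = random.Random(seed)
    trainval = set(rng.sample(ids, int(len(ids) * trainval_ratio)))
    train = set(rng.sample(sorted(trainval), int(len(trainval) * train_ratio)))
    return {
        "trainval": [i for i in ids if i in trainval],
        "train": [i for i in ids if i in train],
        "val": [i for i in ids if i in trainval and i not in train],
        "test": [i for i in ids if i not in trainval],
    }


def write_voc_splits(xml_dir, out_dir, **kwargs):
    """Write <split>.txt files (one image ID per line) for every *.xml in xml_dir."""
    ids = [p.stem for p in Path(xml_dir).glob("*.xml")]
    splits = split_ids(ids, **kwargs)
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for name, subset in splits.items():
        (out / f"{name}.txt").write_text("".join(f"{i}\n" for i in subset))
    return splits
```

A typical call would be write_voc_splits("VOCdevkit/VOC2007/Annotations", "VOCdevkit/VOC2007/ImageSets/Main"). With the default ratios of 1.0, every image goes into trainval and train, and val and test stay empty.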
Then run voc_annotation.py in the root directory. Before running it, change classes to your own class names.

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]


[Image]
This generates 2007_train.txt, in which each line holds the path of one image followed by the positions of its ground-truth boxes.
[Image]
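The sketch below shows how one line of 2007_train.txt can be built from a VOC XML file and read back. Several details are assumptions to check against the generated file: the line format (image path, then one xmin,ymin,xmax,ymax,class_id group per box, all separated by spaces), skipping objects marked difficult, and the requirement that image paths contain no spaces. The function names are hypothetical.

```python
import xml.etree.ElementTree as ET


def voc_xml_to_line(xml_text, image_path, classes):
    """Build '<image_path> xmin,ymin,xmax,ymax,class_id ...' from one VOC XML document.

    Objects whose class is not in `classes`, or that are marked difficult, are
    skipped (an assumption about voc_annotation.py, not stated in the article).
    """
    root = ET.fromstring(xml_text)
    parts = [image_path]
    for obj in root.iter("object"):
        name = obj.find("name").text
        difficult = obj.find("difficult")
        if name not in classes or (difficult is not None and int(difficult.text) == 1):
            continue
        bb = obj.find("bndbox")
        box = [int(float(bb.find(tag).text)) for tag in ("xmin", "ymin", "xmax", "ymax")]
        parts.append(",".join(str(v) for v in box + [classes.index(name)]))
    return " ".join(parts)


def parse_line(line):
    """Split a line back into (image_path, [[xmin, ymin, xmax, ymax, class_id], ...]).

    Assumes the image path itself contains no spaces.
    """
    path, *boxes = line.split()
    return path, [[int(v) for v in b.split(",")] for b in boxes]
```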
Before training, also edit voc_classes.txt in model_data, replacing the classes with your own.
[Image]
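voc_classes.txt lists one class name per line. The loss code above treats label 0 as background (confs_start = 4 + background_label_id + 1) and reads num_classes - 1 foreground confidences, so the num_classes it receives must also count the background. The small sketch below makes that arithmetic explicit. The helper names are hypothetical, and the "+ 1 for background" rule is inferred from the loss code rather than stated in the article.

```python
from pathlib import Path


def parse_class_names(text):
    """One class name per line; blank lines and surrounding whitespace are ignored."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def load_class_names(path):
    """Read model_data/voc_classes.txt (or any file in the same format)."""
    return parse_class_names(Path(path).read_text(encoding="utf-8"))


def confidence_slice(class_names, background_label_id=0):
    """Return (num_classes, confs_start, confs_end) as used by the loss.

    num_classes counts the background class; the slice covers only the
    foreground confidences, mirroring confs_start/confs_end in the loss code.
    """
    num_classes = len(class_names) + 1
    confs_start = 4 + background_label_id + 1
    confs_end = confs_start + num_classes - 1
    return num_classes, confs_start, confs_end
```

For the 20 VOC classes this gives num_classes = 21 and the slice 5:25, which covers exactly 20 foreground confidence columns.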
Run train.py to start training.
[Image]

                                </div>
            <link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-60ecaf1f42.css" rel="stylesheet">
                            </div>