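(From an expert) Building a Faster-RCNN Object Detection Platform with Keras

Original article: https://blog.csdn.net/weixin_44791964/article/details/104451667

Preface

I have recently become interested in instance segmentation. The instance segmentation model MaskRCNN is built on FasterRCNN. I had already studied many one-stage object detection algorithms and was never very interested in FasterRCNN, but this time let's learn FasterRCNN.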

What is the Faster-RCNN object detection algorithm?

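Faster-RCNN is a very effective object detection algorithm. Although the paper is fairly old, it is still the foundation of many object detection algorithms today.

As a two-stage algorithm, Faster-RCNN is more complex and slower than one-stage algorithms, but it generally achieves higher detection accuracy.

This holds in practice: Faster-RCNN detects objects very well, but both its inference speed and its training speed leave room for improvement.

Source code download

https://github.com/bubbliiiing/faster-rcnn-keras
If you find it useful, feel free to give it a star.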

Faster-RCNN implementation walkthrough

I. Prediction

1. The backbone network

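Faster-RCNN can use a variety of backbone feature extraction networks; common choices are VGG, ResNet and Xception. This article uses ResNet. For an introduction to ResNet, see my other blog post: https://blog.csdn.net/weixin_44791964/article/details/102790260

Faster-RCNN does not require a fixed input size, but the short side of the input image is usually resized to 600. For example, a 1200x1800 image is resized to 600x900 without distortion, i.e. the aspect ratio is preserved.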
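The resizing rule above is simple arithmetic: scale both sides by 600 divided by the short side. Below is a minimal sketch of just the size computation; `resize_short_side` is a hypothetical helper, not a function from the repository, and rounding to the nearest integer is an assumption (the repo may round differently). The actual pixel resampling (e.g. with PIL) is not shown.

```python
def resize_short_side(height, width, short_side=600):
    """Return (new_height, new_width) with the short side scaled to short_side.

    Both sides use the same scale factor, so the aspect ratio is kept and the
    image is not distorted. Rounding to the nearest integer is an assumption.
    """
    scale = short_side / min(height, width)
    return int(round(height * scale)), int(round(width * scale))


print(resize_short_side(1200, 1800))  # (600, 900)
```

Because a single scale factor is applied to both sides, the long side can still be any length, which is why the backbone must accept inputs of varying size.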

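ResNet50 is built from two basic blocks, the Conv Block and the Identity Block. The Conv Block's input and output dimensions differ, so its role is to change the network's dimensions (the number of channels and, when strided, the height and width) rather than to be stacked for depth. The Identity Block's input and output dimensions are the same, so it can be stacked repeatedly to deepen the network.
The structure of the Conv Block:
[Figure: Conv Block structure]
The structure of the Identity Block:
[Figure: Identity Block structure]
Both are residual structures.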
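Why the two blocks differ comes down to the shortcut addition: `x + F(x)` is only defined when both tensors have the same shape. The NumPy sketch below illustrates this under heavy simplification. It is not the Keras code: the main branch uses only two 1x1 convolutions (the real blocks also have a 3x3 convolution and BatchNormalization), and the helper names `conv1x1`, `identity_block_np` and `conv_block_np` are hypothetical.

```python
import numpy as np


def conv1x1(x, w, stride=1):
    """1x1 convolution on an (H, W, C_in) array with weights w of shape (C_in, C_out).
    The stride is applied by subsampling rows and columns before mixing channels."""
    return x[::stride, ::stride, :] @ w


def relu(x):
    return np.maximum(x, 0.0)


def identity_block_np(x, w_a, w_c):
    """Simplified Identity Block: the shortcut is the input itself,
    so the main branch must return exactly the input's shape."""
    y = conv1x1(relu(conv1x1(x, w_a)), w_c)
    if y.shape != x.shape:
        raise ValueError(f"identity shortcut needs matching shapes: {x.shape} vs {y.shape}")
    return relu(y + x)


def conv_block_np(x, w_a, w_c, w_short, stride=2):
    """Simplified Conv Block: a strided 1x1 projection on the shortcut changes the
    channel count and downsamples H and W, so the two branches can be added."""
    y = conv1x1(relu(conv1x1(x, w_a, stride)), w_c)
    shortcut = conv1x1(x, w_short, stride)
    return relu(y + shortcut)


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 256))
y = conv_block_np(x, rng.standard_normal((256, 128)), rng.standard_normal((128, 512)),
                  rng.standard_normal((256, 512)))
print(y.shape)  # (4, 4, 512)
z = identity_block_np(y, rng.standard_normal((512, 128)), rng.standard_normal((128, 512)))
print(z.shape)  # (4, 4, 512)
```

Passing the 256-channel `x` to `identity_block_np` with a 512-channel output would raise `ValueError`, which is exactly the situation the Conv Block's projection shortcut handles.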

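The Faster-RCNN backbone uses only the part of ResNet50 that downsamples height and width four times; the fifth downsampling stage is used later, in the ROI part.
Taking a 600x600 input image as an example, the shapes change as follows:
[Figure: shape changes through the ResNet50 backbone layers used by Faster-RCNN, 600x600 input]
The output of the last layer is the shared feature layer.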
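The shape of the shared feature layer can be checked by hand. The sketch below traces one spatial dimension through the layers of the `ResNet50` function in this section, assuming the standard Keras output-size rules: `ceil(n / stride)` for `'same'` padding and `(n - k) // stride + 1` for `'valid'` padding. `conv_out` and `backbone_output_size` are hypothetical helpers written for this illustration.

```python
import math


def conv_out(size, kernel, stride, padding):
    """Output length of one spatial dimension for a Keras-style conv/pool layer."""
    if padding == "same":
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1  # "valid" padding


def backbone_output_size(size):
    """Trace one spatial dimension through the four downsampling steps of the backbone."""
    size = size + 2 * 3                   # ZeroPadding2D((3, 3))
    size = conv_out(size, 7, 2, "valid")  # conv1: 7x7, stride 2
    size = conv_out(size, 3, 2, "same")   # MaxPooling2D: 3x3, stride 2
    # stage 2 conv_block uses strides=(1, 1); identity blocks keep the size
    size = conv_out(size, 1, 2, "valid")  # stage 3 conv_block: first 1x1 conv, stride 2
    size = conv_out(size, 1, 2, "valid")  # stage 4 conv_block: first 1x1 conv, stride 2
    return size


print(backbone_output_size(600))  # 38
print(backbone_output_size(900))  # 57
```

So a 600x600 input gives a 38x38 feature map, and the 600x900 image from the resizing example gives 38x57; four stride-2 steps make each feature-map cell cover roughly 16 input pixels.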

Implementation:

```python
from keras import layers
from keras.layers import Activation, BatchNormalization, Conv2D, MaxPooling2D, ZeroPadding2D


def identity_block(input_tensor, kernel_size, filters, stage, block):
    # Identity Block: output shape equals input shape, so the input is added back unchanged
    filters1, filters2, filters3 = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(name=bn_name_base + '2c')(x)

    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    # Conv Block: changes the channel count (and downsamples H and W when strides=(2, 2));
    # the shortcut uses a strided 1x1 convolution so both branches have the same shape
    filters1, filters2, filters3 = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, padding='same',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(name=bn_name_base + '2c')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(name=bn_name_base + '1')(shortcut)

    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x


def ResNet50(inputs):
    # Only stages 1-4 are built here; stage 5 is used later, in the ROI part
    img_input = inputs

    x = ZeroPadding2D((3, 3))(img_input)
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)

    x = MaxPooling2D((3, 3), strides=(2, 2), padding="same")(x)

    # Stage 2: strides=(1, 1), so only the channel count changes
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    # Stage 3
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    # Stage 4
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    # Shared feature layer: 38x38x1024 for a 600x600x3 input
    return x
```

2. Obtaining proposal boxes

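The shared feature layer obtained above is the Feature Map, and it has two uses. One is to be combined with ROIPooling. The other is to pass through a 3x3 convolution, followed by a 9-channel 1x1 convolution and a separate 36-channel 1x1 convolution.

In Faster-RCNN, num_priors, the number of prior boxes (anchors) per grid point, is 9, so the two 1x1 convolutions produce:

The 9 x 4 convolution predicts, for every grid point on the shared feature layer, the adjustment of each prior box. (It predicts an adjustment because Faster-RCNN obtains its predicted boxes by combining the network output with the prior boxes; the output itself describes how each prior box should change.)

The 9 x 1 convolution predicts, for every grid point on the shared feature layer, whether each predicted box contains an object.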
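To see how 36 channels become "9 anchors x 4 values", the NumPy sketch below mimics the `Reshape((-1, 4))` and `Reshape((-1, 1))` layers in `get_rpn` with random stand-in arrays (batch dimension dropped). Because the reshape uses row-major order, channels `4k` to `4k + 3` at a grid point are read as the 4 values of anchor `k`. That grouping is implied by the reshape; it only works if the anchors are generated in the same cell-major, anchor-minor order, which is an assumption about the repository's anchor code.

```python
import numpy as np

num_anchors = 9
h, w = 38, 38  # shared feature layer size for a 600x600 input

# Random stand-ins for the two 1x1 convolution outputs (no batch dimension).
rng = np.random.default_rng(0)
regr_map = rng.standard_normal((h, w, num_anchors * 4))  # 36 channels
class_map = rng.random((h, w, num_anchors))              # 9 channels

# Same effect as Reshape((-1, 4)) and Reshape((-1, 1)) in get_rpn (row-major order).
regr = regr_map.reshape(-1, 4)
cls = class_map.reshape(-1, 1)
print(regr.shape, cls.shape)  # (12996, 4) (12996, 1)

# Row of anchor k at grid cell (i, j) in the flattened outputs.
i, j, k = 5, 7, 3
row = (i * w + j) * num_anchors + k
assert np.array_equal(regr[row], regr_map[i, j, 4 * k:4 * k + 4])
assert cls[row, 0] == class_map[i, j, k]
```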

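When the input image shape is 600x600x3, the shared feature layer has shape 38x38x1024. This is equivalent to dividing the input image into a 38x38 grid with 9 prior boxes per grid cell; the prior boxes come in different sizes and densely cover the whole image.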
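A sketch of how such a dense set of prior boxes can be generated is shown below. Several details are assumptions, not taken from the repository: the 9 shapes combine the Faster-RCNN paper's default scales (128, 256, 512 pixels) with aspect ratios 1:1, 1:2 and 2:1, the stride is 16 pixels, each anchor is centred on the middle of its grid cell, and boxes are in pixel coordinates (xmin, ymin, xmax, ymax). `generate_anchors` is a hypothetical helper; the repository's own anchor code may differ, for example by normalising coordinates.

```python
import numpy as np


def generate_anchors(feat_h, feat_w, stride=16,
                     scales=(128, 256, 512), ratios=((1, 1), (1, 2), (2, 1))):
    """Return (feat_h * feat_w * 9, 4) anchors as (xmin, ymin, xmax, ymax) in pixels.
    Anchor k of grid cell (i, j) is at row (i * feat_w + j) * 9 + k."""
    # 9 base shapes centred on the origin; each keeps area s*s for its scale s.
    base = []
    for s in scales:
        for rw, rh in ratios:
            bw = s * rw / np.sqrt(rw * rh)
            bh = s * rh / np.sqrt(rw * rh)
            base.append([-bw / 2, -bh / 2, bw / 2, bh / 2])
    base = np.array(base)  # (9, 4)

    # Centre of every grid cell, in input-image pixels.
    cx = (np.arange(feat_w) + 0.5) * stride
    cy = (np.arange(feat_h) + 0.5) * stride
    cx, cy = np.meshgrid(cx, cy)  # both (feat_h, feat_w)
    shifts = np.stack([cx, cy, cx, cy], axis=-1).reshape(-1, 1, 4)

    # Broadcast: every cell centre plus every base shape.
    return (shifts + base[None, :, :]).reshape(-1, 4)


anchors = generate_anchors(38, 38)
print(anchors.shape)  # (12996, 4)
```

The row order matches the flattening used by the `Reshape` layers in `get_rpn`, so row `r` of the anchors lines up with row `r` of the RPN outputs.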

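The output of the 9 x 4 convolution adjusts these prior boxes to produce new boxes.
The 9 x 1 convolution judges whether each of these new boxes contains an object.

At this point we have a set of candidate boxes, each scored by the 9 x 1 convolution for whether it contains an object.

So far these boxes are only rough, i.e. proposal boxes. The next stage looks inside these proposals for the actual objects.

Implementation: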

```python
from keras.layers import Conv2D, Reshape


def get_rpn(base_layers, num_anchors):
    # 3x3 convolution shared by both RPN outputs
    x = Conv2D(512, (3, 3), padding='same', activation='relu', kernel_initializer='normal', name='rpn_conv1')(base_layers)

    # num_anchors (9) channels: whether each prior box at each grid point contains an object
    x_class = Conv2D(num_anchors, (1, 1), activation='sigmoid', kernel_initializer='uniform', name='rpn_out_class')(x)
    # num_anchors * 4 (36) channels: adjustment of each prior box at each grid point
    x_regr = Conv2D(num_anchors * 4, (1, 1), activation='linear', kernel_initializer='zero', name='rpn_out_regress')(x)

    # One row per prior box: (batch, H*W*num_anchors, 1) and (batch, H*W*num_anchors, 4)
    x_class = Reshape((-1, 1), name="classification")(x_class)
    x_regr = Reshape((-1, 4), name="regression")(x_regr)
    return [x_class, x_regr, base_layers]
```

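3. Decoding the proposal boxes

Step 2 gives predictions for 38x38x9 prior boxes. The predictions have two parts:

The 9 x 4 convolution predicts, for every grid point on the shared feature layer, the adjustment of each prior box.

The 9 x 1 convolution predicts, for every grid point on the shared feature layer, whether each predicted box contains an object.

This amounts to dividing the whole image into a 38x38 grid and placing 9 prior boxes at the center of each grid cell, for 38x38x9 = 12996 prior boxes in total.

When the input image shape changes, the number of prior boxes changes too.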
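Prior boxes carry some information about box position and size, but there is only a finite set of them and they cannot represent every possible box, so they must be adjusted.

In 9 x 4, the 9 is the number of prior boxes at the grid point, and the 4 is the adjustment of each box's center (x and y), width and height.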
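The decoding can also be written as equations. This is a transcription of what `decode_boxes` computes; the symbols are introduced here for readability: $(x_a, y_a)$ is the prior box center, $w_a, h_a$ its width and height, and $(t_x, t_y, t_w, t_h)$ the four values predicted for that prior box by the 9 x 4 convolution. The division by 4 is the fixed scaling used by this code.

$$
\begin{aligned}
x &= x_a + \tfrac{t_x}{4}\,w_a, & y &= y_a + \tfrac{t_y}{4}\,h_a,\\
w &= w_a \exp\!\left(\tfrac{t_w}{4}\right), & h &= h_a \exp\!\left(\tfrac{t_h}{4}\right),\\
x_{\min} &= x - \tfrac{w}{2}, & y_{\min} &= y - \tfrac{h}{2},\\
x_{\max} &= x + \tfrac{w}{2}, & y_{\max} &= y + \tfrac{h}{2}.
\end{aligned}
$$

For example, a 128x128 prior box centered at $(300, 300)$ with $t_x = 0.4$ and $t_w = 0$ decodes to $x = 300 + \tfrac{0.4}{4}\cdot 128 = 312.8$ and $w = 128$: the center shifts right by 12.8 pixels and the width is unchanged. The exponential keeps the decoded width and height positive for any predicted value.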

Implementation:


    def decode_boxes(self, mbox_loc, mbox_priorbox):
        # Width and height of the prior boxes (priors are stored as xmin, ymin, xmax, ymax)
        prior_width = mbox_priorbox[:, 2] - mbox_priorbox[:, 0]
        prior_height = mbox_priorbox[:, 3] - mbox_priorbox[:, 1]

        # Center points of the prior boxes
        prior_center_x = 0.5 * (mbox_priorbox[:, 2] + mbox_priorbox[:, 0])
        prior_center_y = 0.5 * (mbox_priorbox[:, 3] + mbox_priorbox[:, 1])

        # Shift the prior centers by the predicted x/y offsets (scaled by 1/4)
        decode_bbox_center_x = mbox_loc[:, 0] * prior_width / 4
        decode_bbox_center_x += prior_center_x
        decode_bbox_center_y = mbox_loc[:, 1] * prior_height / 4
        decode_bbox_center_y += prior_center_y

        # Width and height of the decoded boxes
        decode_bbox_width = np.exp(mbox_loc[:, 2] / 4)
        decode_bbox_width *= prior_width
        decode_bbox_height = np.exp(mbox_loc[:, 3] / 4)
        decode_bbox_height *= prior_height

        # Top-left and bottom-right corners of the decoded boxes
        decode_bbox_xmin = decode_bbox_center_x - 0.5 * decode_bbox_width
        decode_bbox_ymin = decode_bbox_center_y - 0.5 * decode_bbox_height
    decode_bbox_xmax <span class="token operator">=</span> decode_bbox_center_x <span class="token operator">+</span> <span class="token number">0.5</span> <span class="token operator">*</span> decode_bbox_width
    decode_bbox_ymax <span class="token operator">=</span> decode_bbox_center_y <span class="token operator">+</span> <span class="token number">0.5</span> <span class="token operator">*</span> decode_bbox_height

    <span class="token comment"># 真实框的左上角与右下角进行堆叠</span>
    decode_bbox <span class="token operator">=</span> np<span class="token punctuation">.</span>concatenate<span class="token punctuation">(</span><span class="token punctuation">(</span>decode_bbox_xmin<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token boolean">None</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
                                  decode_bbox_ymin<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token boolean">None</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
                                  decode_bbox_xmax<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token boolean">None</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
                                  decode_bbox_ymax<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token boolean">None</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span> axis<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span>
    <span class="token comment"># 防止超出0与1</span>
    decode_bbox <span class="token operator">=</span> np<span class="token punctuation">.</span>minimum<span class="token punctuation">(</span>np<span class="token punctuation">.</span>maximum<span class="token punctuation">(</span>decode_bbox<span class="token punctuation">,</span> <span class="token number">0.0</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token number">1.0</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> decode_bbox

    def detection_out(self, predictions, mbox_priorbox, num_classes, keep_top_k=300,
                      confidence_threshold=0.5):
        # Network predictions: class confidences and box offsets
        mbox_conf = predictions[0]
        mbox_loc = predictions[1]
        results = []
        # Process each image in the batch
        for i in range(len(mbox_loc)):
            results.append([])
            # Decode this image's offsets against the prior boxes (mbox_priorbox)
            decode_bbox = self.decode_boxes(mbox_loc[i], mbox_priorbox)
            for c in range(num_classes):
                c_confs = mbox_conf[i, :, c]
                c_confs_m = c_confs > confidence_threshold
                if len(c_confs[c_confs_m]) > 0:
                    # Keep the boxes that score above confidence_threshold
                    boxes_to_process = decode_bbox[c_confs_m]
                    confs_to_process = c_confs[c_confs_m]
                    # IoU-based non-maximum suppression (self.nms is run in a TF session)
                    feed_dict = {self.boxes: boxes_to_process,
                                 self.scores: confs_to_process}
                    idx = self.sess.run(self.nms, feed_dict=feed_dict)
                    # Keep the boxes that survived NMS
                    good_boxes = boxes_to_process[idx]
                    confs = confs_to_process[idx][:, None]
                    # Stack label, confidence and box: [label, conf, xmin, ymin, xmax, ymax]
                    labels = c * np.ones((len(idx), 1))
                    c_pred = np.concatenate((labels, confs, good_boxes), axis=1)
                    # Add them to this image's results
                    results[-1].extend(c_pred)

            if len(results[-1]) > 0:
                # Sort by confidence, highest first
                results[-1] = np.array(results[-1])
                argsort = np.argsort(results[-1][:, 1])[::-1]
                results[-1] = results[-1][argsort]
                # Keep the keep_top_k most confident boxes
                results[-1] = results[-1][:keep_top_k]
        # For each image: the high-confidence boxes, whose positions (the predicted boxes)
        # were obtained by decoding Faster-RCNN's outputs against the prior boxes
        return results
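To see what the "/ 4" in decode_boxes does, here is a standalone NumPy sketch of the same decoding together with its inverse. encode is a hypothetical helper written only for this illustration (it is not taken from the repository); it assumes normalized [xmin, ymin, xmax, ymax] boxes as in decode_boxes and leaves out the final clip to [0, 1]. Dividing the predicted offsets by 4 at decode time corresponds to offsets that were multiplied by 4 when the training targets were encoded.

```python
import numpy as np


def decode(loc, priors):
    """Decode offsets (N, 4) against prior boxes (N, 4) given as [xmin, ymin, xmax, ymax]."""
    pw = priors[:, 2] - priors[:, 0]
    ph = priors[:, 3] - priors[:, 1]
    pcx = 0.5 * (priors[:, 0] + priors[:, 2])
    pcy = 0.5 * (priors[:, 1] + priors[:, 3])
    # Center offsets are scaled by a quarter of the prior's size, as in decode_boxes
    cx = pcx + loc[:, 0] * pw / 4
    cy = pcy + loc[:, 1] * ph / 4
    # Width and height are scaled by exp(t / 4)
    w = pw * np.exp(loc[:, 2] / 4)
    h = ph * np.exp(loc[:, 3] / 4)
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)


def encode(boxes, priors):
    """Hypothetical inverse of decode: the offsets that map each prior onto its box."""
    pw = priors[:, 2] - priors[:, 0]
    ph = priors[:, 3] - priors[:, 1]
    pcx = 0.5 * (priors[:, 0] + priors[:, 2])
    pcy = 0.5 * (priors[:, 1] + priors[:, 3])
    bw = boxes[:, 2] - boxes[:, 0]
    bh = boxes[:, 3] - boxes[:, 1]
    bcx = 0.5 * (boxes[:, 0] + boxes[:, 2])
    bcy = 0.5 * (boxes[:, 1] + boxes[:, 3])
    return np.stack([(bcx - pcx) / pw * 4,
                     (bcy - pcy) / ph * 4,
                     np.log(bw / pw) * 4,
                     np.log(bh / ph) * 4], axis=-1)


priors = np.array([[0.2, 0.2, 0.4, 0.4]])
gt = np.array([[0.25, 0.15, 0.45, 0.5]])
# Zero offsets reproduce the prior box itself (up to floating-point error)
print(decode(np.zeros((1, 4)), priors))
# Encoding and then decoding recovers the target box
print(np.allclose(decode(encode(gt, priors), priors), gt))  # True
```

A positive first offset moves the box to the right by a quarter of the prior's width per unit, which is why small network outputs still produce meaningful shifts.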

4. Using the proposal boxes (RoiPoolingConv)

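First, an overall picture of what the proposal boxes are:
a proposal box is a preliminary screening of which regions of the image contain an object.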

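The backbone feature-extraction network produces a shared feature map. For a 600x600x3 input image its shape is 38x38x1024, and the proposal boxes are used to crop regions out of this shared feature map.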

The 38x38 grid of the shared feature map corresponds to a 38x38 grid of regions in the image; each point condenses all the features inside its region.
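The 38x38 size can be checked with a little arithmetic. This is a minimal sketch assuming that each of the backbone's four downsampling steps halves the height and width and rounds up (the exact padding of each ResNet layer is not shown in this section); feature_map_size is a hypothetical helper:

```python
import math


def feature_map_size(input_size, num_downsamples=4):
    """Spatial size after repeated stride-2 downsampling, rounding up each time."""
    size = input_size
    for _ in range(num_downsamples):
        size = math.ceil(size / 2)
    return size


size = feature_map_size(600)
print(size)        # 38
print(600 / size)  # roughly how many input pixels each feature-map cell spans
```

Under the same assumption, a 600x900 input (the resized 1200x1800 example from the backbone section) gives a 38x57 feature map, so each cell covers roughly a 16x16-pixel patch of the input.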

Each proposal box crops out the part of this 38x38 grid where it expects an object to be, and the cropped result is then resized to 14x14x1024.
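The crop-and-resize step can be sketched in NumPy. This is an illustration under assumptions, not the layer used in the article: the proposal is taken as a normalized [xmin, ymin, xmax, ymax] box, box_to_roi is a hypothetical helper that turns it into the integer (x, y, w, h) feature-map cells that RoiPoolingConv slices with, and the resize uses nearest-neighbour sampling for simplicity, whereas tf.image.resize_images defaults to bilinear interpolation.

```python
import numpy as np


def box_to_roi(box, fmap_h, fmap_w):
    """Hypothetical helper: normalized [xmin, ymin, xmax, ymax] -> integer (x, y, w, h) in cells."""
    x1 = min(int(np.floor(box[0] * fmap_w)), fmap_w - 1)
    y1 = min(int(np.floor(box[1] * fmap_h)), fmap_h - 1)
    x2 = min(int(np.ceil(box[2] * fmap_w)), fmap_w)
    y2 = min(int(np.ceil(box[3] * fmap_h)), fmap_h)
    # Keep at least one cell so the crop is never empty
    return x1, y1, max(x2 - x1, 1), max(y2 - y1, 1)


def crop_and_resize(fmap, roi, pool_size=14):
    """Crop (x, y, w, h) from an (H, W, C) feature map and resize it to (pool_size, pool_size, C)."""
    x, y, w, h = roi
    crop = fmap[y:y + h, x:x + w, :]
    # Nearest-neighbour sampling: pick evenly spaced source rows and columns
    rows = (np.arange(pool_size) * crop.shape[0] / pool_size).astype(int)
    cols = (np.arange(pool_size) * crop.shape[1] / pool_size).astype(int)
    return crop[rows][:, cols]


fmap = np.random.default_rng(0).random((38, 38, 1024), dtype=np.float32)
roi = box_to_roi([0.1, 0.2, 0.5, 0.6], 38, 38)
pooled = crop_and_resize(fmap, roi)
print(roi, pooled.shape)  # (3, 7, 16, 16) (14, 14, 1024)
```

Whatever the size of the crop, every proposal comes out as the same 14x14x1024 block, which is what lets the following layers treat all proposals uniformly.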

By default, 32 proposal boxes are fed in at a time.

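Each proposal box then goes through the fifth downsampling stage of the original ResNet. This is followed by average pooling and a Flatten, and finally by two fully connected layers: one with num_classes outputs and one with (num_classes-1)x4 outputs.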
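In the code for this step, every stage-5 layer is wrapped in TimeDistributed, so the same layer, with the same weights, is applied to each proposal along the ROI axis (axis 1). The NumPy sketch below illustrates that idea outside Keras: pointwise_conv is a hypothetical stand-in for a 1x1 Conv2D (a matrix multiply over the channel axis), and the 64-channel input is only kept small for the example. Looping over the proposals gives the same result as folding the ROI axis into the batch axis, applying the layer once, and unfolding again:

```python
import numpy as np


def pointwise_conv(x, kernel):
    """Hypothetical 1x1 convolution: (..., H, W, C_in) @ (C_in, C_out) -> (..., H, W, C_out)."""
    return x @ kernel


rng = np.random.default_rng(0)
roi_features = rng.standard_normal((1, 32, 14, 14, 64))  # (batch, num_rois, H, W, C_in)
kernel = rng.standard_normal((64, 16))                   # one set of weights shared by all ROIs

# Option 1: loop over the ROI axis, applying the same kernel to each proposal
looped = np.stack([pointwise_conv(roi_features[:, i], kernel)
                   for i in range(roi_features.shape[1])], axis=1)

# Option 2: fold the ROIs into the batch axis, apply once, then unfold
b, n = roi_features.shape[:2]
folded = pointwise_conv(roi_features.reshape(b * n, *roi_features.shape[2:]), kernel)
unfolded = folded.reshape(b, n, *folded.shape[1:])

print(looped.shape)                   # (1, 32, 14, 14, 16)
print(np.allclose(looped, unfolded))  # True
```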

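The num_classes layer classifies the final box, and the (num_classes-1)x4 layer adjusts the corresponding proposal box. The -1 is there because no adjustment is predicted for boxes classified as background.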
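The shape changes described above can be traced with simple arithmetic. This is a minimal sketch assuming the defaults used in this article (32 proposals, 14x14x1024 crops) and the nb_classes=21 default of get_classifier; classifier_head_shapes is a hypothetical helper, not part of the repository:

```python
def classifier_head_shapes(num_rois=32, pool_size=14, in_channels=1024, num_classes=21):
    """Trace the output shape of each step of the ROI classifier head (batch size 1)."""
    shapes = {"roi_pool": (1, num_rois, pool_size, pool_size, in_channels)}
    # Stage 5 Conv Block: the stride-2 1x1 conv ('valid' padding) halves the side; channels -> 2048
    side = (pool_size - 1) // 2 + 1
    shapes["stage5"] = (1, num_rois, side, side, 2048)
    # AveragePooling2D((7, 7)) with its default stride of 7 and 'valid' padding
    pooled = (side - 7) // 7 + 1
    shapes["avg_pool"] = (1, num_rois, pooled, pooled, 2048)
    shapes["flatten"] = (1, num_rois, pooled * pooled * 2048)
    # One softmax score per class, background included
    shapes["dense_class"] = (1, num_rois, num_classes)
    # Four box adjustments per non-background class
    shapes["dense_regress"] = (1, num_rois, 4 * (num_classes - 1))
    return shapes


for name, shape in classifier_head_shapes().items():
    print(name, shape)
```

With the defaults, stage 5 yields (1, 32, 7, 7, 2048), which is exactly what the 7x7 average pooling expects, and the regression layer has 4 x 20 = 80 outputs per proposal.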

These operations give the adjustment for every proposal box, together with the class of the object inside the adjusted box.

In effect, the proposal boxes obtained in the previous step act as the prior boxes of the ROI stage.

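The process of using the proposal boxes, and the shape changes along the way, are shown in the figure below:
[Figure: using the proposal boxes, with shape changes]
The adjusted proposal boxes are the final prediction results and can be drawn on the image.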

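# K.image_dim_ordering() and tf.image.resize_images belong to Keras 2.x / TensorFlow 1.x
# and were removed in later versions
import tensorflow as tf
import keras.backend as K
from keras.layers import (Layer, TimeDistributed, Conv2D, BatchNormalization,
                          Activation, Add, AveragePooling2D, Flatten, Dense)


class RoiPoolingConv(Layer):
    """Crops each ROI from the shared feature map and resizes it to pool_size x pool_size."""
    def __init__(self, pool_size, num_rois, **kwargs):
        self.dim_ordering = K.image_dim_ordering()
        assert self.dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
        # Side length of each pooled output (14 in this article)
        self.pool_size = pool_size
        # Number of ROIs handled per call (32 by default in this article)
        self.num_rois = num_rois
        super(RoiPoolingConv, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape[0] is the feature map shape (batch, rows, cols, channels)
        self.nb_channels = input_shape[0][3]

    def compute_output_shape(self, input_shape):
        return None, self.num_rois, self.pool_size, self.pool_size, self.nb_channels

    def call(self, x, mask=None):
        assert(len(x) == 2)
        # x[0]: shared feature map, shape (1, rows, cols, channels)
        img = x[0]
        # x[1]: ROIs, shape (1, num_rois, 4); each row is (x, y, w, h) in feature-map cells
        rois = x[1]
        outputs = []
        for roi_idx in range(self.num_rois):
            x = K.cast(rois[0, roi_idx, 0], 'int32')
            y = K.cast(rois[0, roi_idx, 1], 'int32')
            w = K.cast(rois[0, roi_idx, 2], 'int32')
            h = K.cast(rois[0, roi_idx, 3], 'int32')
            # Crop the ROI from the feature map and resize it to pool_size x pool_size
            rs = tf.image.resize_images(img[:, y:y + h, x:x + w, :], (self.pool_size, self.pool_size))
            outputs.append(rs)
        # Stack the crops and reshape to (1, num_rois, pool_size, pool_size, channels)
        final_output = K.concatenate(outputs, axis=0)
        final_output = K.reshape(final_output, (1, self.num_rois, self.pool_size, self.pool_size, self.nb_channels))
        return final_output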

def identity_block_td(input_tensor, kernel_size, filters, stage, block, trainable=True):
    # TimeDistributed version of ResNet's Identity Block: the same layers are
    # applied to every ROI independently (axis 1 is the ROI axis)
    nb_filter1, nb_filter2, nb_filter3 = filters
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # 1x1 conv reduces the channel count
    x = TimeDistributed(Conv2D(nb_filter1, (1, 1), trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '2a')(input_tensor)
    x = TimeDistributed(BatchNormalization(axis=bn_axis), name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    # kernel_size x kernel_size conv; 'same' padding keeps the spatial size
    x = TimeDistributed(Conv2D(nb_filter2, (kernel_size, kernel_size), trainable=trainable, kernel_initializer='normal', padding='same'), name=conv_name_base + '2b')(x)
    x = TimeDistributed(BatchNormalization(axis=bn_axis), name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    # 1x1 conv restores the channel count to nb_filter3
    x = TimeDistributed(Conv2D(nb_filter3, (1, 1), trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '2c')(x)
    x = TimeDistributed(BatchNormalization(axis=bn_axis), name=bn_name_base + '2c')(x)

    # Residual connection: input and output shapes match, so they are added directly
    x = Add()([x, input_tensor])
    x = Activation('relu')(x)

    return x

def conv_block_td(input_tensor, kernel_size, filters, stage, block, input_shape, strides=(2, 2), trainable=True):
    # TimeDistributed version of ResNet's Conv Block: the stride changes the spatial size,
    # so the shortcut branch needs its own 1x1 conv before the Add
    nb_filter1, nb_filter2, nb_filter3 = filters
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # Main branch: strided 1x1 conv
    x = TimeDistributed(Conv2D(nb_filter1, (1, 1), strides=strides, trainable=trainable, kernel_initializer='normal'), input_shape=input_shape, name=conv_name_base + '2a')(input_tensor)
    x = TimeDistributed(BatchNormalization(axis=bn_axis), name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    # kernel_size x kernel_size conv with 'same' padding
    x = TimeDistributed(Conv2D(nb_filter2, (kernel_size, kernel_size), padding='same', trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '2b')(x)
    x = TimeDistributed(BatchNormalization(axis=bn_axis), name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    # 1x1 conv up to nb_filter3 channels (trainable is set on the Conv2D, as in the other layers)
    x = TimeDistributed(Conv2D(nb_filter3, (1, 1), trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '2c')(x)
    x = TimeDistributed(BatchNormalization(axis=bn_axis), name=bn_name_base + '2c')(x)

    # Shortcut branch: strided 1x1 conv so its shape matches the main branch
    shortcut = TimeDistributed(Conv2D(nb_filter3, (1, 1), strides=strides, trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '1')(input_tensor)
    shortcut = TimeDistributed(BatchNormalization(axis=bn_axis), name=bn_name_base + '1')(shortcut)

    x = Add()([x, shortcut])
    x = Activation('relu')(x)
    return x

def classifier_layers(x, input_shape, trainable=False):
    # ResNet50 stage 5, applied to every ROI independently via TimeDistributed
    x = conv_block_td(x, 3, [512, 512, 2048], stage=5, block='a', input_shape=input_shape, strides=(2, 2), trainable=trainable)
    x = identity_block_td(x, 3, [512, 512, 2048], stage=5, block='b', trainable=trainable)
    x = identity_block_td(x, 3, [512, 512, 2048], stage=5, block='c', trainable=trainable)
    # 7x7 average pooling reduces each ROI to a single 2048-d vector
    x = TimeDistributed(AveragePooling2D((7, 7)), name='avg_pool')(x)

    return x

def get_classifier(base_layers, input_rois, num_rois, nb_classes=21, trainable=False):
    # Every ROI is pooled to a 14x14 patch of the 1024-channel shared feature map
    pooling_regions = 14
    input_shape = (num_rois, 14, 14, 1024)
    out_roi_pool = RoiPoolingConv(pooling_regions, num_rois)([base_layers, input_rois])
    out = classifier_layers(out_roi_pool, input_shape=input_shape, trainable=True)
    out = TimeDistributed(Flatten())(out)
    # Class scores (background included) for every ROI
    out_class = TimeDistributed(Dense(nb_classes, activation='softmax', kernel_initializer='zero'), name='dense_class_{}'.format(nb_classes))(out)
    # Box adjustments: 4 values for every non-background class
    out_regr = TimeDistributed(Dense(4 * (nb_classes - 1), activation='linear', kernel_initializer='zero'), name='dense_regress_{}'.format(nb_classes))(out)
    return [out_class, out_regr]
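get_classifier relies on RoiPoolingConv to turn proposals of different sizes into fixed 14x14 patches before stage 5. As a rough illustration, here is classic ROI max pooling in numpy. The function name roi_max_pool and the (x, y, w, h) feature-map-cell format of the ROI are assumptions for this sketch, and the repository's RoiPoolingConv layer may implement the pooling differently (for example by resizing the crop).

```python
import numpy as np


def roi_max_pool(feature_map, roi, pool_size=14):
    """Classic ROI max pooling (illustrative sketch).

    feature_map: (H, W, C) array, e.g. the shared feature layer.
    roi: (x, y, w, h) in integer feature-map cells (assumed format).
    Returns a (pool_size, pool_size, C) array.
    """
    x, y, w, h = roi
    region = feature_map[y:y + h, x:x + w, :]
    out = np.zeros((pool_size, pool_size, feature_map.shape[2]), dtype=feature_map.dtype)
    # Bin edges that split the region into pool_size x pool_size cells
    ys = np.linspace(0, h, pool_size + 1)
    xs = np.linspace(0, w, pool_size + 1)
    for i in range(pool_size):
        y0 = int(np.floor(ys[i]))
        y1 = max(int(np.ceil(ys[i + 1])), y0 + 1)  # every bin covers at least one cell
        for j in range(pool_size):
            x0 = int(np.floor(xs[j]))
            x1 = max(int(np.ceil(xs[j + 1])), x0 + 1)
            # Max over the bin, independently for every channel
            out[i, j] = region[y0:y1, x0:x1].max(axis=(0, 1))
    return out
```

Whatever the size of the proposal, the output always has the same spatial size, which is what lets a single stage-5 network process every ROI.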


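5. Drawing the results on the original image

At the end of step 4, the proposals are decoded a second time, which gives the positions of the predicted boxes on the original image; these boxes have already been filtered. Drawing them directly on the image gives the final result.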
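Drawing a box is only a matter of coloring its four edges. The sketch below works on a plain numpy image so it stays self-contained; it assumes the final boxes are given as (x1, y1, x2, y2) pixel coordinates on the original image, draw_boxes is a hypothetical name, and class labels and scores are not drawn.

```python
import numpy as np


def draw_boxes(image, boxes, color=(255, 0, 0), thickness=2):
    """Draw box outlines in place on an (H, W, 3) uint8 image (illustrative sketch).

    boxes: iterable of (x1, y1, x2, y2) pixel coordinates (assumed format).
    """
    h, w = image.shape[:2]
    t = thickness
    for x1, y1, x2, y2 in boxes:
        # Clip to the image so boxes touching the border do not go out of range
        x1, x2 = int(max(0, x1)), int(min(w - 1, x2))
        y1, y2 = int(max(0, y1)), int(min(h - 1, y2))
        image[y1:y1 + t, x1:x2 + 1] = color                  # top edge
        image[max(y2 - t + 1, 0):y2 + 1, x1:x2 + 1] = color  # bottom edge
        image[y1:y2 + 1, x1:x1 + t] = color                  # left edge
        image[y1:y2 + 1, max(x2 - t + 1, 0):x2 + 1] = color  # right edge
    return image
```

In practice an image library such as PIL would also be used to write the class name and score next to each box.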

6. Overall execution flow

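A few tips:
1. There are two decoding passes in total.
2. A coarse filtering step comes first, followed by fine adjustment.
3. The proposals obtained from the first decoding pass are used to crop the shared feature map.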
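Tip 3 is the step that links the two stages: a decoded proposal has to be mapped onto the shared feature map before it can be cropped. calc_iou in the training section does this by scaling with the image width/height and dividing by config.rpn_stride. The sketch below assumes proposals given as normalized [x1, y1, x2, y2] and a stride of 16 (four 2x downsamplings in the backbone); proposals_to_feature_cells is a hypothetical helper name.

```python
import numpy as np


def proposals_to_feature_cells(proposals, width, height, stride=16):
    """Map normalized proposals onto feature-map cells (illustrative sketch).

    proposals: (N, 4) array of [x1, y1, x2, y2] in the range 0-1 (assumed format).
    Returns an (N, 4) int array of [x, y, w, h] in feature-map cells.
    """
    scale = np.array([width, height, width, height], dtype=float) / stride
    boxes = np.round(proposals * scale).astype(int)
    x, y = boxes[:, 0], boxes[:, 1]
    # Keep at least one cell so that the crop is never empty
    w = np.maximum(boxes[:, 2] - x, 1)
    h = np.maximum(boxes[:, 3] - y, 1)
    return np.stack([x, y, w, h], axis=1)
```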

II. Training

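Like prediction, Faster-RCNN training is split into two parts: first the proposal network is trained, then the network that uses the ROIs to produce the final predictions.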

1. Training the proposal network

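To obtain proposal predictions from the shared feature layer, a 3x3 convolution is applied first, followed by two parallel 1x1 convolutions: one with 9 channels and one with 36 channels.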

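In Faster-RCNN, num_priors, the number of anchors (prior boxes) at each grid point, is 9, so the two 1x1 convolutions produce:

A 9 x 4 convolution that predicts the adjustment of every anchor at every grid point of the shared feature layer. (It is called an adjustment because Faster-RCNN combines its predictions with the anchors to obtain the predicted boxes; the prediction describes how each anchor should change.)

A 9 x 1 convolution that predicts, for every anchor at every grid point of the shared feature layer, whether it contains an object.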
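Because a 1x1 convolution applies the same matrix multiply at every grid point, the two heads and the reshape that follows them can be sketched in numpy. The random weights, the sigmoid on the 9-channel output and the anchor-major channel layout (anchor 0's four values first, then anchor 1's, and so on) are assumptions made for illustration; in the Keras model these heads are Conv2D layers, and rpn_head_1x1 is a hypothetical name.

```python
import numpy as np


def rpn_head_1x1(features, w_cls, b_cls, w_reg, b_reg, num_priors=9):
    """The two 1x1 convolutions of the proposal head (illustrative sketch).

    features: (H, W, C) output of the 3x3 convolution.
    w_cls: (C, num_priors), w_reg: (C, num_priors * 4).
    """
    # A 1x1 convolution is a per-grid-point matrix multiply over channels
    logits = features @ w_cls + b_cls  # (H, W, 9)
    deltas = features @ w_reg + b_reg  # (H, W, 36)
    # Sigmoid turns each logit into "this anchor contains an object"
    scores = 1.0 / (1.0 + np.exp(-logits))
    # One row per anchor: grid points in row-major order, 9 anchors per point
    return scores.reshape(-1, 1), deltas.reshape(-1, num_priors * 4 // num_priors)
```

For a 38x38 feature map this gives 38 x 38 x 9 = 12996 anchors, each with one objectness score and four adjustment values.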

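In other words, the raw output of the Faster-RCNN proposal network is not the actual position of the proposals on the image; it has to be decoded to obtain the real positions.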
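Decoding is the inverse of encode_box: the first two predicted values shift the anchor center in units of the anchor size, and the last two rescale its width and height through an exponential. A minimal sketch, assuming [x1, y1, x2, y2] anchors and the factor of 4 that encode_box uses in this article; decode_boxes is a hypothetical name.

```python
import numpy as np


def decode_boxes(deltas, priors, scale=4.0):
    """Turn predicted adjustments back into boxes (illustrative sketch).

    deltas: (N, 4) predictions [tx, ty, tw, th].
    priors: (N, 4) anchors [x1, y1, x2, y2].
    """
    prior_wh = priors[:, 2:4] - priors[:, :2]
    prior_center = 0.5 * (priors[:, :2] + priors[:, 2:4])
    # Shift the center by a fraction of the anchor size
    center = prior_center + deltas[:, :2] / scale * prior_wh
    # Rescale width and height; exp keeps them positive
    wh = prior_wh * np.exp(deltas[:, 2:4] / scale)
    return np.concatenate([center - 0.5 * wh, center + 0.5 * wh], axis=1)
```

An all-zero prediction returns the anchor unchanged, which is why the predictions are described as changes relative to the anchors.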

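During training we have to compute a loss, and this loss is defined on the raw predictions of the Faster-RCNN proposal network. So besides feeding the image through the current proposal network to get its predictions, we also have to encode the ground truth: convert the position of each ground-truth box into the same format as the proposal network's predictions.

That is, for every ground-truth box in every training image we find its matching anchors and work out what the proposal prediction would have to be in order to produce exactly that ground-truth box.

Going from proposal predictions to real boxes is called decoding; going from real boxes to proposal predictions is called encoding.

So encoding is simply decoding run in reverse.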
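Written out, the encoding performed by encode_box (including its factor of 4) and the decoding it reverses are as follows, where $(x, y, w, h)$ are the center and size of the ground-truth box, $(x_a, y_a, w_a, h_a)$ those of the matched anchor, and $(t_x, t_y, t_w, t_h)$ the target prediction:

$$
t_x = 4\,\frac{x - x_a}{w_a},\qquad t_y = 4\,\frac{y - y_a}{h_a},\qquad t_w = 4\log\frac{w}{w_a},\qquad t_h = 4\log\frac{h}{h_a}
$$

$$
x = x_a + \frac{t_x}{4}\,w_a,\qquad y = y_a + \frac{t_y}{4}\,h_a,\qquad w = w_a\,e^{t_w/4},\qquad h = h_a\,e^{t_h/4}
$$

Substituting the first line into the second returns the original $(x, y, w, h)$, which is exactly the sense in which encoding is decoding run backwards.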

The implementation is as follows:

def encode_box(self, box, return_iou=True):
    iou = self.iou(box)
    encoded_box = np.zeros((self.num_priors, 4 + return_iou))

    # Find the anchors that overlap this ground-truth box strongly
    assign_mask = iou > self.overlap_threshold
    # If none passes the threshold, keep at least the single best anchor
    if not assign_mask.any():
        assign_mask[iou.argmax()] = True
    if return_iou:
        encoded_box[:, -1][assign_mask] = iou[assign_mask]

    # The matching anchors
    assigned_priors = self.priors[assign_mask]
    # Reverse encoding: convert the ground-truth box into the format of the Faster-RCNN predictions
    # First compute the center and width/height of the ground-truth box
    box_center = 0.5 * (box[:2] + box[2:])
    box_wh = box[2:] - box[:2]
    # Then the centers and widths/heights of the matching anchors
    assigned_priors_center = 0.5 * (assigned_priors[:, :2] +
                                    assigned_priors[:, 2:4])
    assigned_priors_wh = (assigned_priors[:, 2:4] -
                          assigned_priors[:, :2])

    # Work backwards to the predictions the proposal network should produce
    encoded_box[:, :2][assign_mask] = box_center - assigned_priors_center
    encoded_box[:, :2][assign_mask] /= assigned_priors_wh
    encoded_box[:, :2][assign_mask] *= 4

    encoded_box[:, 2:4][assign_mask] = np.log(box_wh / assigned_priors_wh)
    encoded_box[:, 2:4][assign_mask] *= 4
    return encoded_box.ravel()

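With the code above we obtain, for a ground-truth box, all anchors that have a high IoU with it, and compute the predictions each of those anchors should produce.

However, an image may contain several ground-truth boxes, so the same anchor can overlap strongly with more than one of them; we only keep the ground-truth box it overlaps most.

So one more filtering step is needed: among the encoded results above, each anchor keeps only the ground-truth box with the highest IoU.

assign_boxes then tells us what the predictions for this input image should look like.

The implementation is as follows: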
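The selection is easiest to see on a plain IoU matrix. The sketch below mirrors what encode_box and assign_boxes do together: each ground-truth box keeps the anchors above overlap_threshold (or its single best anchor if none passes), and then each anchor keeps the ground-truth box with the highest remaining IoU. assign_anchors is a hypothetical name and 0.7 is an assumed threshold.

```python
import numpy as np


def assign_anchors(iou, overlap_threshold=0.7):
    """Pick one ground-truth box per anchor (illustrative sketch).

    iou: (num_gt, num_priors) IoU matrix.
    Returns (num_priors,) ground-truth indices, -1 for unassigned anchors.
    """
    # Per ground-truth box: anchors above the threshold...
    keep = iou > overlap_threshold
    # ...or, if none passes, its single best anchor
    rows = np.flatnonzero(~keep.any(axis=1))
    keep[rows, iou[rows].argmax(axis=1)] = True
    masked = np.where(keep, iou, 0.0)
    # Per anchor: the ground-truth box with the highest kept IoU wins
    best_iou = masked.max(axis=0)
    best_idx = masked.argmax(axis=0)
    return np.where(best_iou > 0, best_idx, -1)
```

For example, if two ground-truth boxes both overlap anchor 0 with IoUs 0.8 and 0.75, anchor 0 is assigned only to the first one.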

def iou(self, box):
    # Compute the IoU between one ground-truth box and every anchor
    # Intersection of the ground-truth box with each anchor
    inter_upleft = np.maximum(self.priors[:, :2], box[:2])
    inter_botright = np.minimum(self.priors[:, 2:4], box[2:])

    inter_wh = inter_botright - inter_upleft
    inter_wh = np.maximum(inter_wh, 0)
    inter = inter_wh[:, 0] * inter_wh[:, 1]
    # Area of the ground-truth box
    area_true = (box[2] - box[0]) * (box[3] - box[1])
    # Area of each anchor
    area_gt = (self.priors[:, 2] - self.priors[:, 0]) * (self.priors[:, 3] - self.priors[:, 1])
    # IoU = intersection / union
    union = area_true + area_gt - inter

    iou = inter / union
    return iou

def encode_box(self, box, return_iou=True):
    iou = self.iou(box)
    encoded_box = np.zeros((self.num_priors, 4 + return_iou))

    # Find the anchors that overlap this ground-truth box strongly
    assign_mask = iou > self.overlap_threshold
    # If none passes the threshold, keep at least the single best anchor
    if not assign_mask.any():
        assign_mask[iou.argmax()] = True
    if return_iou:
        encoded_box[:, -1][assign_mask] = iou[assign_mask]

    # The matching anchors
    assigned_priors = self.priors[assign_mask]
    # Reverse encoding: convert the ground-truth box into the format of the Faster-RCNN predictions
    # First compute the center and width/height of the ground-truth box
    box_center = 0.5 * (box[:2] + box[2:])
    box_wh = box[2:] - box[:2]
    # Then the centers and widths/heights of the matching anchors
    assigned_priors_center = 0.5 * (assigned_priors[:, :2] +
                                    assigned_priors[:, 2:4])
    assigned_priors_wh = (assigned_priors[:, 2:4] -
                          assigned_priors[:, :2])

    # Work backwards to the predictions the proposal network should produce
    encoded_box[:, :2][assign_mask] = box_center - assigned_priors_center
    encoded_box[:, :2][assign_mask] /= assigned_priors_wh
    encoded_box[:, :2][assign_mask] *= 4

    encoded_box[:, 2:4][assign_mask] = np.log(box_wh / assigned_priors_wh)
    encoded_box[:, 2:4][assign_mask] *= 4
    return encoded_box.ravel()

def ignore_box(self, box):
    iou = self.iou(box)

    ignored_box = np.zeros((self.num_priors, 1))

    # Anchors whose IoU with this ground-truth box lies between the two thresholds are ignored
    assign_mask = (iou > self.ignore_threshold) & (iou < self.overlap_threshold)

    if not assign_mask.any():
        assign_mask[iou.argmax()] = True

    ignored_box[:, 0][assign_mask] = iou[assign_mask]
    return ignored_box.ravel()

def assign_boxes(self, boxes, anchors):
    self.num_priors = len(anchors)
    self.priors = anchors
    # Columns 0-3: regression targets; column 4: 1 = object, 0 = background, -1 = ignored
    assignment = np.zeros((self.num_priors, 4 + 1))

    assignment[:, 4] = 0.0
    if len(boxes) == 0:
        return assignment

    # Find the anchors to ignore for every ground-truth box
    ignored_boxes = np.apply_along_axis(self.ignore_box, 1, boxes[:, :4])
    # (n, num_priors, 1)
    ignored_boxes = ignored_boxes.reshape(-1, self.num_priors, 1)
    # (num_priors)
    ignore_iou = ignored_boxes[:, :, 0].max(axis=0)
    # (num_priors)
    ignore_iou_mask = ignore_iou > 0

    assignment[:, 4][ignore_iou_mask] = -1

    # Encode every ground-truth box against every anchor: (n, num_priors * 5)
    encoded_boxes = np.apply_along_axis(self.encode_box, 1, boxes[:, :4])
    # Encoded values and IoU for every ground-truth box: (n, num_priors, 5)
    encoded_boxes = encoded_boxes.reshape(-1, self.num_priors, 5)

    # For every anchor, find the ground-truth box it overlaps most
    # (num_priors)
    best_iou = encoded_boxes[:, :, -1].max(axis=0)
    # (num_priors)
    best_iou_idx = encoded_boxes[:, :, -1].argmax(axis=0)
    # (num_priors)
    best_iou_mask = best_iou > 0
    # Which ground-truth box each assigned anchor belongs to
    best_iou_idx = best_iou_idx[best_iou_mask]

    assign_num = len(best_iou_idx)
    # Keep only the anchors that were assigned a ground-truth box
    encoded_boxes = encoded_boxes[:, best_iou_mask, :]

    # Regression targets from the best-matching ground-truth box of each assigned anchor
    assignment[:, :4][best_iou_mask] = encoded_boxes[best_iou_idx, np.arange(assign_num), :4]
    # Column 4 is the objectness label: 1 marks an anchor that contains an object
    assignment[:, 4][best_iou_mask] = 1
    # assignment now holds what the predictions for this input image should look like
    return assignment

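During proposal-network training, anchors whose overlap with a ground-truth box is fairly high but not high enough are ignored: anchors with an IoU between 0.3 and 0.7 are usually excluded from the loss. This is what ignore_box and the -1 label in assign_boxes implement.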
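Put together, each anchor ends up with one of three labels depending on its best IoU with any ground-truth box. The sketch assumes ignore_threshold = 0.3 and overlap_threshold = 0.7 (the values quoted above) and uses the same strict comparisons as ignore_box and encode_box; rpn_labels is a hypothetical name, and it leaves out the rule in encode_box that always keeps each ground-truth box's single best anchor.

```python
import numpy as np


def rpn_labels(best_iou, ignore_threshold=0.3, overlap_threshold=0.7):
    """Objectness labels for the proposal network (illustrative sketch).

    best_iou: (num_priors,) highest IoU of each anchor with any ground-truth box.
    Returns 1 = object, 0 = background, -1 = ignored (no loss).
    """
    labels = np.zeros_like(best_iou, dtype=np.int64)
    # In-between overlaps are ambiguous, so they are excluded from the loss
    labels[(best_iou > ignore_threshold) & (best_iou < overlap_threshold)] = -1
    labels[best_iou > overlap_threshold] = 1
    return labels
```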

2、Roi网络的训练

通过上一步已经可以对建议框网络进行训练了,建议框网络会提供一些位置的建议,在ROI网络部分,其会将建议框根据进行一定的截取,并获得对应的预测结果,事实上就是将上一步建议框当作了ROI网络的先验框。

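We therefore compute the overlap (IoU) between every proposal and every ground-truth box and filter the proposals: a proposal whose best IoU with a ground-truth box is 0.5 or higher is treated as a positive sample, and one whose best IoU lies between 0.1 and 0.5 is treated as a negative sample. Proposals with a best IoU below 0.1 are not used for training.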
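The selection rule can be written down directly. This is only a sketch: `label_proposal` is a hypothetical helper, its two thresholds stand in for config.classifier_min_overlap and config.classifier_max_overlap used in calc_iou below, and, as in that code, -1 means background.

```python
def iou(a, b):
    """IoU of two boxes given as [x1, y1, x2, y2]."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def label_proposal(proposal, gt_boxes, gt_classes, min_overlap=0.1, max_overlap=0.5):
    """Return the class id for a positive, -1 for a negative (background), None if discarded."""
    if len(gt_boxes) == 0:
        return None
    ious = [iou(proposal, g) for g in gt_boxes]
    best = max(range(len(ious)), key=lambda k: ious[k])  # index of the best-matching box
    if ious[best] < min_overlap:
        return None  # barely touches any object: not used at all
    if ious[best] < max_overlap:
        return -1  # overlaps, but not enough: background sample
    return gt_classes[best]  # positive: inherits the class of the matched box
```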

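We then encode each ground-truth box relative to its proposal. In other words, given these proposals, the encoding states what the ROI network has to predict in order to adjust each proposal onto its ground-truth box.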
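The encoding is the usual center/size parameterization that calc_iou below computes as tx, ty, tw, th: with box centers (x, y), widths w and heights h, t_x = (x_g - x_a) / w_a, t_y = (y_g - y_a) / h_a, t_w = log(w_g / w_a), t_h = log(h_g / h_a), where g is the ground-truth box and a the proposal. A sketch with its inverse, showing that decoding the targets recovers the ground-truth box; `encode` and `decode` are hypothetical names, and the extra scaling by config.classifier_regr_std that calc_iou applies is left out.

```python
import numpy as np


def encode(gt, roi):
    """Regression target that moves roi onto gt; both boxes are [x1, y1, x2, y2]."""
    w, h = roi[2] - roi[0], roi[3] - roi[1]
    cx, cy = roi[0] + w / 2.0, roi[1] + h / 2.0
    gw, gh = gt[2] - gt[0], gt[3] - gt[1]
    gcx, gcy = gt[0] + gw / 2.0, gt[1] + gh / 2.0
    return np.array([(gcx - cx) / w, (gcy - cy) / h, np.log(gw / w), np.log(gh / h)])


def decode(t, roi):
    """Inverse of encode: apply the offsets t = [tx, ty, tw, th] to roi."""
    w, h = roi[2] - roi[0], roi[3] - roi[1]
    cx, cy = roi[0] + w / 2.0, roi[1] + h / 2.0
    pcx, pcy = cx + t[0] * w, cy + t[1] * h  # shift the center
    pw, ph = w * np.exp(t[2]), h * np.exp(t[3])  # rescale width and height
    return np.array([pcx - pw / 2.0, pcy - ph / 2.0, pcx + pw / 2.0, pcy + ph / 2.0])
```

Normalizing the offsets by the proposal size and using a log for the scale keeps the targets comparable across small and large proposals.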

Each training step feeds 32 proposals to the ROI network, and positive and negative samples must be kept balanced.
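The balancing rule reduces to a few lines. A sketch, assuming a per-proposal label array in which -1 marks background (as calc_iou produces) and using NumPy's Generator API; `sample_rois` is a hypothetical helper. Like the training code below, it takes at most num_rois // 2 positives, fills the rest with negatives, samples negatives with replacement only when there are too few, and returns None when there are no negatives, mirroring the `continue` in that code.

```python
import numpy as np


def sample_rois(labels, num_rois=32, rng=None):
    """Pick num_rois proposal indices: at most half positives (label >= 0), the rest background (-1)."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    pos = np.where(labels >= 0)[0]
    neg = np.where(labels == -1)[0]
    if len(neg) == 0:
        return None  # no background proposals: skip this image
    n_pos = min(len(pos), num_rois // 2)
    chosen_pos = rng.choice(pos, n_pos, replace=False)
    n_neg = num_rois - n_pos
    # Sample negatives with replacement only when there are not enough of them
    chosen_neg = rng.choice(neg, n_neg, replace=len(neg) < n_neg)
    return np.concatenate([chosen_pos, chosen_neg])
```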
The implementation is as follows:

# Encoding: build the ROI-network training targets for the proposals R
import copy

import numpy as np

# iou(a, b) is a helper defined elsewhere in the project; it returns the IoU of
# two [x1, y1, x2, y2] boxes.


def calc_iou(R, config, all_boxes, width, height, num_classes):
    # Ground-truth boxes are normalized; convert them to feature-map cells
    bboxes = all_boxes[:, :4]
    gta = np.zeros((len(bboxes), 4))
    for bbox_num, bbox in enumerate(bboxes):
        gta[bbox_num, 0] = int(round(bbox[0] * width / config.rpn_stride))
        gta[bbox_num, 1] = int(round(bbox[1] * height / config.rpn_stride))
        gta[bbox_num, 2] = int(round(bbox[2] * width / config.rpn_stride))
        gta[bbox_num, 3] = int(round(bbox[3] * height / config.rpn_stride))
    x_roi = []
    y_class_num = []
    y_class_regr_coords = []
    y_class_regr_label = []
    IoUs = []
    for ix in range(R.shape[0]):
        # Proposals are normalized as well; map them onto the feature map
        x1 = int(round(R[ix, 0] * width / config.rpn_stride))
        y1 = int(round(R[ix, 1] * height / config.rpn_stride))
        x2 = int(round(R[ix, 2] * width / config.rpn_stride))
        y2 = int(round(R[ix, 3] * height / config.rpn_stride))

        # Find the ground-truth box that overlaps this proposal the most
        best_iou = 0.0
        best_bbox = -1
        for bbox_num in range(len(bboxes)):
            curr_iou = iou([gta[bbox_num, 0], gta[bbox_num, 1], gta[bbox_num, 2], gta[bbox_num, 3]], [x1, y1, x2, y2])
            if curr_iou > best_iou:
                best_iou = curr_iou
                best_bbox = bbox_num

        # Proposals that barely touch any object are not used at all
        if best_iou < config.classifier_min_overlap:
            continue

        w = x2 - x1
        h = y2 - y1
        x_roi.append([x1, y1, w, h])
        IoUs.append(best_iou)

        if config.classifier_min_overlap <= best_iou < config.classifier_max_overlap:
            # Negative sample: label -1 selects the last class slot, i.e. background
            label = -1
        elif config.classifier_max_overlap <= best_iou:
            # Positive sample: take the matched box's class and encode it relative to the proposal
            label = int(all_boxes[best_bbox, -1])
            cxg = (gta[best_bbox, 0] + gta[best_bbox, 2]) / 2.0
            cyg = (gta[best_bbox, 1] + gta[best_bbox, 3]) / 2.0

            cx = x1 + w / 2.0
            cy = y1 + h / 2.0

            tx = (cxg - cx) / float(w)
            ty = (cyg - cy) / float(h)
            tw = np.log((gta[best_bbox, 2] - gta[best_bbox, 0]) / float(w))
            th = np.log((gta[best_bbox, 3] - gta[best_bbox, 1]) / float(h))
        else:
            raise RuntimeError('roi = {}'.format(best_iou))

        # One-hot class target; index -1 (the last entry) is background
        class_label = num_classes * [0]
        class_label[label] = 1
        y_class_num.append(copy.deepcopy(class_label))
        # Regression targets and mask: 4 values for every non-background class
        coords = [0] * 4 * (num_classes - 1)
        labels = [0] * 4 * (num_classes - 1)
        if label != -1:
            label_pos = 4 * label
            sx, sy, sw, sh = config.classifier_regr_std
            coords[label_pos:4 + label_pos] = [sx * tx, sy * ty, sw * tw, sh * th]
            # The mask marks which 4 entries contribute to the regression loss
            labels[label_pos:4 + label_pos] = [1, 1, 1, 1]
        y_class_regr_coords.append(copy.deepcopy(coords))
        y_class_regr_label.append(copy.deepcopy(labels))

    if len(x_roi) == 0:
        return None, None, None, None

    X = np.array(x_roi)
    Y1 = np.array(y_class_num)
    # Y2 stacks the regression mask in front of the regression targets
    Y2 = np.concatenate([np.array(y_class_regr_label), np.array(y_class_regr_coords)], axis=1)

    return np.expand_dims(X, axis=0), np.expand_dims(Y1, axis=0), np.expand_dims(Y2, axis=0), IoUs

# Balancing positive and negative samples
# (fragment from the training loop; `continue` skips to the next batch, X is the image batch)
X2, Y1, Y2, IouS = calc_iou(R, config, boxes[0], width, height, NUM_CLASSES)

if X2 is None:
    rpn_accuracy_rpn_monitor.append(0)
    rpn_accuracy_for_epoch.append(0)
    continue

# The last column of Y1 is the background flag
neg_samples = np.where(Y1[0, :, -1] == 1)[0]
pos_samples = np.where(Y1[0, :, -1] == 0)[0]

rpn_accuracy_rpn_monitor.append(len(pos_samples))
rpn_accuracy_for_epoch.append(len(pos_samples))

if len(neg_samples) == 0:
    continue

# Use at most config.num_rois // 2 positives; negatives fill the rest of the batch
if len(pos_samples) < config.num_rois // 2:
    selected_pos_samples = pos_samples.tolist()
else:
    selected_pos_samples = np.random.choice(pos_samples, config.num_rois // 2, replace=False).tolist()
try:
    selected_neg_samples = np.random.choice(neg_samples, config.num_rois - len(selected_pos_samples), replace=False).tolist()
except ValueError:
    # Too few negatives: sample them with replacement instead
    selected_neg_samples = np.random.choice(neg_samples, config.num_rois - len(selected_pos_samples), replace=True).tolist()

sel_samples = selected_pos_samples + selected_neg_samples
loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])


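Training Your Own Faster-RCNN Model

The overall folder structure of Faster-RCNN is as follows:
[Figure]
This post trains on data in VOC format.
Before training, put the annotation files in the Annotations folder under VOCdevkit/VOC2007.
[Figure]
Before training, put the image files in the JPEGImages folder under VOCdevkit/VOC2007.
[Figure]
Before training, run voc2faster-rcnn.py to generate the corresponding txt files.
[Figure]
Then run voc_annotation.py in the root directory. Before running it, change classes to your own classes.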

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

[Figure]
This generates 2007_train.txt, in which each line holds an image path followed by the positions of its ground-truth boxes.
[Figure]
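To make the line format concrete, here is a minimal parser. It assumes each line is the image path followed by space-separated `x1,y1,x2,y2,class_id` groups with integer values; that exact layout is an assumption, so check it against your generated file. `parse_annotation_line` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def parse_annotation_line(line):
    """Split one line of 2007_train.txt into (image_path, boxes).

    ASSUMED format: "<image_path> x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id ..."
    boxes is an (N, 5) integer array; N is 0 for an image without objects.
    """
    parts = line.split()  # whitespace split, so image paths must not contain spaces
    image_path = parts[0]
    boxes = np.array(
        [[int(v) for v in group.split(",")] for group in parts[1:]], dtype=np.int64
    )
    return image_path, boxes.reshape(-1, 5)
```

Reading the file back this way is a quick sanity check that voc_annotation.py picked up the right classes before a long training run.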
Before training, also edit voc_classes.txt in model_data and change the classes to your own classes.
[Figure]
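A sketch of reading voc_classes.txt, assuming one class name per line. The `+ 1` for background matches calc_iou above, which reserves the last of its `num_classes` slots for background and allocates `4 * (num_classes - 1)` regression outputs. `load_classes` and `num_classes_with_background` are hypothetical helpers.

```python
def load_classes(path):
    """Read class names, one per line, skipping blank lines (ASSUMED voc_classes.txt layout)."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def num_classes_with_background(class_names):
    """Number of classifier outputs: every object class plus one background slot."""
    return len(class_names) + 1
```

For example, `num_classes_with_background(load_classes("model_data/voc_classes.txt"))` gives 21 for the 20 VOC classes listed above.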
Run train.py to start training.
[Figure]
