ResNet50 and Its Keras Implementation

If you already understand the theory, jump straight to the ResNet50 implementation: 卷积神经网络 第三周作业:Residual+Networks±+v1

You may have seen the blog post "ResNet解析", which has more than 120,000 views, but the way its first section describes ResNet is completely different from Andrew Ng's account, so I remain skeptical of that post. You can find a link to the paper that proposed the network at the very bottom of that post, and it can serve as a starting point for studying the paper.

ResNet = Residual Network
All non-residual networks are referred to as plain networks, a relative term introduced in the original paper.

The residual network is a deep convolutional network proposed in 2015 by the well-known researcher Kaiming He (何凯明); upon release it won first place in the ImageNet image classification, detection, and localization tasks. Residual networks are easier to optimize and can gain accuracy from considerably increased depth. Their core contribution is to remove the side effect of adding depth (the degradation problem), so that network performance can be improved simply by making the network deeper.

Motivation of ResNet

In theory, the deeper a neural network is, the richer the features it can compute and the better the results it should achieve; the only drawback of a deeper network would be its huge number of trainable parameters and the large amount of compute they require.

In practice, however, as the network gets deeper you find that the magnitude (norm) of the gradients drops sharply. This is called vanishing gradients, and it makes learning extremely slow. In rare cases the gradients instead grow explosively, the exploding-gradient phenomenon. The visible symptom is that training-set accuracy, far from improving over a shallower network, actually gets worse.

The residual network was proposed precisely to counter the vanishing gradients that come with deeper networks.

What does Residual mean?

Suppose the underlying mapping realized by a few stacked layers of the network is H(X), where X is the input to the first of those layers. If multiple nonlinear layers can represent a sufficiently complex hypothesis, then learning H(X) is equivalent to learning the residual function F(X) = H(X) − X, which approaches that hypothesis just as well, and the original mapping can then be written as F(X) + X. Both H(X) and F(X) are, in essence, approximations of the hypothesis.

(The passage above is translated from the original paper.)
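To make this concrete, here is a minimal sketch in the Keras functional API (my own illustration, not code from the paper; the shapes and filter counts are invented and BatchNorm is omitted) of a block that computes F(X) with two weight layers and adds the untouched input X back before the final ReLU:

from keras.layers import Input, Conv2D, Activation, Add

X_in = Input((32, 32, 64))                    # X: input to the block (shape invented for the example)
F = Conv2D(64, (3, 3), padding="same")(X_in)  # first weight layer
F = Activation("relu")(F)
F = Conv2D(64, (3, 3), padding="same")(F)     # second weight layer; together they compute F(X)
H = Add()([F, X_in])                          # H(X) = F(X) + X
H = Activation("relu")(H)                     # the final ReLU is applied after the addition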

Based on this idea, ResNet introduces two kinds of mapping: the identity mapping and the residual mapping. The identity mapping is the path in the figure above that skips the two weight layers and feeds X directly into the ReLU two layers later; the residual mapping is the original plain-network part. It is called an identity because the weight layers are skipped and no computation is applied along that path, i.e. G(X) = X.

Andrew Ng's explanation in the videos is that we take the output X of some layer, skip over several consecutive layers, and feed it in just before the ReLU of a later layer. This makes it very easy for a Residual Block to learn the identity mapping (since the skipped path only passes through a ReLU).

Of course, this requires the dimensions of X to match the dimensions at the ReLU where it is added. The figure above shows how, when the dimensions differ, they are matched via W_s × a^[l]; W_s can simply be a zero-padding matrix, or a parameter matrix that is learned.
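When W_s is learned, it is typically realized as a 1×1 convolution on the shortcut path (this is also what the convolutional block later in this post does). A rough sketch, with shapes and filter counts invented for illustration:

from keras.layers import Input, Conv2D, Activation, Add

a_prev = Input((16, 16, 64))                                      # a[l], before the skipped layers
F = Conv2D(128, (3, 3), strides=(2, 2), padding="same")(a_prev)   # main path halves H, W and doubles the channels
Ws_a = Conv2D(128, (1, 1), strides=(2, 2))(a_prev)                # learned W_s: 1x1 conv that matches F's shape
out = Activation("relu")(Add()([F, Ws_a]))                        # shapes now agree, so the addition is valid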

Shortcut Connection & Residual Block

The residual block presented in Andrew Ng's course is simplified; below is the detailed representation from the original paper:

In the figure, the curved arrow is the shortcut connection: it skips over two weight layers (it takes a shortcut). A few layers of the plain network plus a shortcut connection make up one Residual Block. The shortcut lets every residual block easily learn the identity mapping, and during backpropagation it lets the gradient flow directly to the shallower layers.
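A quick way to see the gradient claim: write the block's output as y = F(x) + x. The chain rule then gives ∂L/∂x = ∂L/∂y · (∂F/∂x + I), so even when ∂F/∂x becomes very small, the identity term I still carries ∂L/∂y straight back to the shallower layer.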

Applying Residual Blocks repeatedly to a plain network yields a Residual Network. How many blocks to use, and where in which network to place them, takes real insight and a thorough understanding of deep neural networks; in practice people follow the original paper and related work, and varying these choices is also how new, similar networks are constructed.

The paper proposes two kinds of Residual Block.

The residual block on the left keeps the input dimension, while the one on the right is a "bottleneck design": the first layer reduces the 256-dimensional input to 64 dimensions, the third layer restores it to 256 dimensions, and the shortcut/skip connection bypasses all three layers and feeds the input straight to the ReLU.
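A sketch of that bottleneck in Keras (illustrative only; BatchNorm, strides and layer names omitted, spatial size invented), matching the 256 → 64 → 64 → 256 channel counts described above:

from keras.layers import Input, Conv2D, Activation, Add

X_in = Input((56, 56, 256))
X = Conv2D(64, (1, 1))(X_in)                  # 1x1 conv: 256 -> 64 channels (the "bottleneck")
X = Activation("relu")(X)
X = Conv2D(64, (3, 3), padding="same")(X)     # 3x3 conv at the reduced width
X = Activation("relu")(X)
X = Conv2D(256, (1, 1))(X)                    # 1x1 conv: back up to 256 channels
X = Activation("relu")(Add()([X, X_in]))      # the shortcut skips all three layers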

ResNet50

ResNet50 is a residual network with 50 layers in total; layers with no trainable parameters, such as pooling layers, are not counted.
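For reference, the count works out as: 1 conv layer in Stage 1, plus Stages 2 to 5 with 3 + 4 + 6 + 3 = 16 bottleneck blocks of 3 conv layers each (48 layers), plus the final fully connected layer, giving 1 + 48 + 1 = 50.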

The original paper proposes several common residual networks that differ mainly in depth; the 50-layer and 101-layer versions are the most common.

The 50-layer ResNet is built from two kinds of structure, the identity block (恒等块) and the convolutional block (卷积块), shown below.

Identity block. Skip connection “skips over” 3 layers
The convolutional block

The main difference between the two structures is whether a convolution is applied on the shortcut connection.

ResNet50 assembled from these two kinds of block is shown in the figure below:

ResNet50 consists of roughly 5 stages, i.e. 5 convolutional phases with different parameters, as shown above. (Note: the original paper counts the max pooling as the start of Stage 2.)

"filters of size" below refers to the numbers of filters in the three conv layers of each block, not to the kernel size f; the parameter f is given in the 50-layer ResNet column of the table above and is also spelled out below.

  • Zero-padding pads the input with a pad of (3,3)
  • Stage 1:
    • The 2D Convolution has 64 filters of shape (7,7) and uses a stride of (2,2). Its name is “conv1”.
    • BatchNorm is applied to the channels axis of the input.
    • MaxPooling uses a (3,3) window and a (2,2) stride.
  • Stage 2:
    • The convolutional block uses three sets of filters of size [64,64,256], “f” is 3, “s” is 1 and the block is “a”.
    • The 2 identity blocks use three sets of filters of size [64,64,256], “f” is 3 and the blocks are “b” and “c”.
  • Stage 3:
    • The convolutional block uses three sets of filters of size [128,128,512], “f” is 3, “s” is 2 and the block is “a”.
    • The 3 identity blocks use three sets of filters of size [128,128,512], “f” is 3 and the blocks are “b”, “c” and “d”.
  • Stage 4:
    • The convolutional block uses three sets of filters of size [256, 256, 1024], “f” is 3, “s” is 2 and the block is “a”.
    • The 5 identity blocks use three sets of filters of size [256, 256, 1024], “f” is 3 and the blocks are “b”, “c”, “d”, “e” and “f”.
  • Stage 5:
    • The convolutional block uses three sets of filters of size [512, 512, 2048], “f” is 3, “s” is 2 and the block is “a”.
    • The 2 identity blocks use three sets of filters of size [256, 256, 2048], “f” is 3 and the blocks are “b” and “c”.
    • The 2D Average Pooling uses a window of shape (2,2) and its name is “avg_pool”.

Keras Implementation

This implementation is the assignment from Andrew Ng's deep learning video series; if you want to master it fully, I strongly recommend the following post, which contains the complete assignment walkthrough and explanations:

卷积神经网络 第三周作业:Residual+Networks±+v1

The code follows.

Import the required libraries:
import numpy as np
import tensorflow as tf
from keras import layers
from keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D
from keras.models import Model, load_model
from keras.preprocessing import image
from keras.utils import layer_utils
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import preprocess_input
# import pydot
# from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.utils import plot_model
from resnets_utils import *
from keras.initializers import glorot_uniform
import scipy.misc
from matplotlib.pyplot import imshow
%matplotlib inline

import keras.backend as K
K.set_image_data_format('channels_last')
K.set_learning_phase(1)

Identity Block

# GRADED FUNCTION: identity_block

def identity_block(X, f, filters, stage, block):
    """
    Implementation of the identity block as defined in Figure 4

    Arguments:
    X -- input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
    f -- integer, specifying the shape of the middle CONV's window for the main path
    filters -- python list of integers, defining the number of filters in the CONV layers of the main path
    stage -- integer, used to name the layers, depending on their position in the network
    block -- string/character, used to name the layers, depending on their position in the network

    Returns:
    X -- output of the identity block, tensor of shape (n_H, n_W, n_C)
    """

    # defining name basis
    conv_name_base = "res" + str(stage) + block + "_branch"
    bn_name_base   = "bn"  + str(stage) + block + "_branch"

    # Retrieve Filters
    F1, F2, F3 = filters

    # Save the input value. You'll need this later to add back to the main path.
    X_shortcut = X

    # First component of main path
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(1, 1), padding="valid",
               name=conv_name_base + "2a", kernel_initializer=glorot_uniform(seed=0))(X)
    # valid means no padding / glorot_uniform is equivalent to Xavier initialization - Steve
    X = BatchNormalization(axis=3, name=bn_name_base + "2a")(X)
    X = Activation("relu")(X)

    ### START CODE HERE ###

    # Second component of main path (≈3 lines)
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding="same",
               name=conv_name_base + "2b", kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + "2b")(X)
    X = Activation("relu")(X)

    # Third component of main path (≈2 lines)
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding="valid",
               name=conv_name_base + "2c", kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + "2c")(X)

    # Final step: Add shortcut value to main path, and pass it through a RELU activation (≈2 lines)
    X = Add()([X, X_shortcut])
    X = Activation("relu")(X)

    ### END CODE HERE ###

    return X

Convolutional Block

# GRADED FUNCTION: convolutional_block

def convolutional_block(X, f, filters, stage, block, s=2):
    """
    Implementation of the convolutional block as defined in Figure 4

    Arguments:
    X -- input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
    f -- integer, specifying the shape of the middle CONV's window for the main path
    filters -- python list of integers, defining the number of filters in the CONV layers of the main path
    stage -- integer, used to name the layers, depending on their position in the network
    block -- string/character, used to name the layers, depending on their position in the network
    s -- Integer, specifying the stride to be used

    Returns:
    X -- output of the convolutional block, tensor of shape (n_H, n_W, n_C)
    """

    # defining name basis
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # Retrieve Filters
    F1, F2, F3 = filters

    # Save the input value
    X_shortcut = X

    ##### MAIN PATH #####
    # First component of main path
    X = Conv2D(F1, (1, 1), strides=(s, s), name=conv_name_base + '2a', padding='valid',
               kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)

    ### START CODE HERE ###

    # Second component of main path (≈3 lines)
    X = Conv2D(F2, (f, f), strides=(1, 1), name=conv_name_base + '2b', padding='same',
               kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)

    # Third component of main path (≈2 lines)
    X = Conv2D(F3, (1, 1), strides=(1, 1), name=conv_name_base + '2c', padding='valid',
               kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    ##### SHORTCUT PATH ##### (≈2 lines)
    X_shortcut = Conv2D(F3, (1, 1), strides=(s, s), name=conv_name_base + '1', padding='valid',
                        kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(axis=3, name=bn_name_base + '1')(X_shortcut)

    # Final step: Add shortcut value to main path, and pass it through a RELU activation (≈2 lines)
    X = layers.add([X, X_shortcut])
    X = Activation('relu')(X)

    ### END CODE HERE ###

    return X

The ResNet50 Model

# GRADED FUNCTION: ResNet50

def ResNet50(input_shape = (64, 64, 3), classes = 6):
“”"
Implementation of the popular ResNet50 the following architecture:
CONV2D -> BATCHNORM -> RELU -> MAXPOOL -> CONVBLOCK -> IDBLOCK2 -> CONVBLOCK -> IDBLOCK3
-> CONVBLOCK -> IDBLOCK5 -> CONVBLOCK -> IDBLOCK2 -> AVGPOOL -> TOPLAYER

Arguments<span class="token punctuation">:</span>
input_shape <span class="token operator">--</span> shape <span class="token keyword">of</span> the images <span class="token keyword">of</span> the dataset
classes <span class="token operator">--</span> integer<span class="token punctuation">,</span> number <span class="token keyword">of</span> classes

Returns<span class="token punctuation">:</span>
model <span class="token operator">--</span> a <span class="token function">Model</span><span class="token punctuation">(</span><span class="token punctuation">)</span> instance <span class="token keyword">in</span> Keras
<span class="token string">""</span>"

# Define the input <span class="token keyword">as</span> a tensor <span class="token keyword">with</span> shape input_shape
X_input <span class="token operator">=</span> <span class="token function">Input</span><span class="token punctuation">(</span>input_shape<span class="token punctuation">)</span>


# Zero<span class="token operator">-</span>Padding
X <span class="token operator">=</span> <span class="token function">ZeroPadding2D</span><span class="token punctuation">(</span><span class="token punctuation">(</span><span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">(</span>X_input<span class="token punctuation">)</span>

# Stage <span class="token number">1</span>
X <span class="token operator">=</span> <span class="token function">Conv2D</span><span class="token punctuation">(</span>filters<span class="token operator">=</span><span class="token number">64</span><span class="token punctuation">,</span> kernel_size<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">7</span><span class="token punctuation">,</span> <span class="token number">7</span><span class="token punctuation">)</span><span class="token punctuation">,</span> strides<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">,</span> name<span class="token operator">=</span><span class="token string">"conv"</span><span class="token punctuation">,</span>
           kernel_initializer<span class="token operator">=</span><span class="token function">glorot_uniform</span><span class="token punctuation">(</span>seed<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">(</span>X<span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">BatchNormalization</span><span class="token punctuation">(</span>axis<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> name<span class="token operator">=</span><span class="token string">"bn_conv1"</span><span class="token punctuation">)</span><span class="token punctuation">(</span>X<span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">Activation</span><span class="token punctuation">(</span><span class="token string">"relu"</span><span class="token punctuation">)</span><span class="token punctuation">(</span>X<span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">MaxPooling2D</span><span class="token punctuation">(</span>pool_size<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> strides<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">(</span>X<span class="token punctuation">)</span>

# Stage <span class="token number">2</span>
X <span class="token operator">=</span> <span class="token function">convolutional_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"a"</span><span class="token punctuation">,</span> s<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"b"</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"c"</span><span class="token punctuation">)</span>
### START CODE HERE ###

# Stage <span class="token number">3</span> <span class="token punctuation">(</span>≈<span class="token number">4</span> lines<span class="token punctuation">)</span>
# The convolutional block uses three <span class="token keyword">set</span> <span class="token keyword">of</span> filters <span class="token keyword">of</span> size <span class="token punctuation">[</span><span class="token number">128</span><span class="token punctuation">,</span><span class="token number">128</span><span class="token punctuation">,</span><span class="token number">512</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token string">"f"</span> is <span class="token number">3</span><span class="token punctuation">,</span> <span class="token string">"s"</span> is <span class="token number">2</span> and the block is <span class="token string">"a"</span><span class="token punctuation">.</span>
# The <span class="token number">3</span> identity blocks use three <span class="token keyword">set</span> <span class="token keyword">of</span> filters <span class="token keyword">of</span> size <span class="token punctuation">[</span><span class="token number">128</span><span class="token punctuation">,</span><span class="token number">128</span><span class="token punctuation">,</span><span class="token number">512</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token string">"f"</span> is <span class="token number">3</span> and the blocks are <span class="token string">"b"</span><span class="token punctuation">,</span> <span class="token string">"c"</span> and <span class="token string">"d"</span><span class="token punctuation">.</span>
X <span class="token operator">=</span> <span class="token function">convolutional_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"a"</span><span class="token punctuation">,</span> s<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"b"</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"c"</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"d"</span><span class="token punctuation">)</span>

# Stage <span class="token number">4</span> <span class="token punctuation">(</span>≈<span class="token number">6</span> lines<span class="token punctuation">)</span>
# The convolutional block uses three <span class="token keyword">set</span> <span class="token keyword">of</span> filters <span class="token keyword">of</span> size <span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token string">"f"</span> is <span class="token number">3</span><span class="token punctuation">,</span> <span class="token string">"s"</span> is <span class="token number">2</span> and the block is <span class="token string">"a"</span><span class="token punctuation">.</span>
# The <span class="token number">5</span> identity blocks use three <span class="token keyword">set</span> <span class="token keyword">of</span> filters <span class="token keyword">of</span> size <span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token string">"f"</span> is <span class="token number">3</span> and the blocks are <span class="token string">"b"</span><span class="token punctuation">,</span> <span class="token string">"c"</span><span class="token punctuation">,</span> <span class="token string">"d"</span><span class="token punctuation">,</span> <span class="token string">"e"</span> and <span class="token string">"f"</span><span class="token punctuation">.</span>
X <span class="token operator">=</span> <span class="token function">convolutional_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"a"</span><span class="token punctuation">,</span> s<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"b"</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"c"</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"d"</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"e"</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"f"</span><span class="token punctuation">)</span>


# Stage <span class="token number">5</span> <span class="token punctuation">(</span>≈<span class="token number">3</span> lines<span class="token punctuation">)</span>
# The convolutional block uses three <span class="token keyword">set</span> <span class="token keyword">of</span> filters <span class="token keyword">of</span> size <span class="token punctuation">[</span><span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">2048</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token string">"f"</span> is <span class="token number">3</span><span class="token punctuation">,</span> <span class="token string">"s"</span> is <span class="token number">2</span> and the block is <span class="token string">"a"</span><span class="token punctuation">.</span>
# The <span class="token number">2</span> identity blocks use three <span class="token keyword">set</span> <span class="token keyword">of</span> filters <span class="token keyword">of</span> size <span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">2048</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token string">"f"</span> is <span class="token number">3</span> and the blocks are <span class="token string">"b"</span> and <span class="token string">"c"</span><span class="token punctuation">.</span>
X <span class="token operator">=</span> <span class="token function">convolutional_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">2048</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">5</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"a"</span><span class="token punctuation">,</span> s<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">2048</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">5</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"b"</span><span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">identity_block</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> f<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> filters<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">2048</span><span class="token punctuation">]</span><span class="token punctuation">,</span> stage<span class="token operator">=</span><span class="token number">5</span><span class="token punctuation">,</span> block<span class="token operator">=</span><span class="token string">"c"</span><span class="token punctuation">)</span>

# filters should be <span class="token punctuation">[</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">2048</span><span class="token punctuation">]</span><span class="token punctuation">,</span> but it fail to be graded<span class="token punctuation">.</span> Use <span class="token punctuation">[</span><span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">2048</span><span class="token punctuation">]</span> to pass the grading


# <span class="token function">AVGPOOL</span> <span class="token punctuation">(</span>≈<span class="token number">1</span> line<span class="token punctuation">)</span><span class="token punctuation">.</span> Use <span class="token string">"X = AveragePooling2D(...)(X)"</span>
# The 2D Average Pooling uses a window <span class="token keyword">of</span> <span class="token function">shape</span> <span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span><span class="token number">2</span><span class="token punctuation">)</span> and its name is <span class="token string">"avg_pool"</span><span class="token punctuation">.</span>
X <span class="token operator">=</span> <span class="token function">AveragePooling2D</span><span class="token punctuation">(</span>pool_size<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token string">"same"</span><span class="token punctuation">)</span><span class="token punctuation">(</span>X<span class="token punctuation">)</span>

### END CODE HERE ###

# output layer
X <span class="token operator">=</span> <span class="token function">Flatten</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">(</span>X<span class="token punctuation">)</span>
X <span class="token operator">=</span> <span class="token function">Dense</span><span class="token punctuation">(</span>classes<span class="token punctuation">,</span> activation<span class="token operator">=</span><span class="token string">"softmax"</span><span class="token punctuation">,</span> name<span class="token operator">=</span><span class="token string">"fc"</span><span class="token operator">+</span><span class="token function">str</span><span class="token punctuation">(</span>classes<span class="token punctuation">)</span><span class="token punctuation">,</span> kernel_initializer<span class="token operator">=</span><span class="token function">glorot_uniform</span><span class="token punctuation">(</span>seed<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">(</span>X<span class="token punctuation">)</span>


# Create model
model = Model(inputs=X_input, outputs=X, name="ResNet50")

return model

Define the model

model = ResNet50(input_shape=(64, 64, 3), classes=6)

Compile the model

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

Load the training data

X_train_orig, Y_train_orig, X_test_orig, Y_test_orig, classes = load_dataset()
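These cells assume the imports and helpers from the course notebook: identity_block and convolutional_block are defined in earlier cells, while load_dataset and convert_to_one_hot ship with the assignment's resnets_utils module. A minimal sketch of the imports a standalone script would need, assuming Keras 2.x:

import numpy as np
from keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D
from keras.models import Model
from keras.initializers import glorot_uniform
# load_dataset / convert_to_one_hot come from the assignment's resnets_utils module
from resnets_utils import load_dataset, convert_to_one_hot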

Normalize image vectors

X_train = X_train_orig/255.
X_test = X_test_orig/255.

"""
def convert_to_one_hot(Y, C):
    Y = np.eye(C)[Y.reshape(-1)].T
    return Y
"""

Convert training and test labels to one hot matrices

Y_train = convert_to_one_hot(Y_train_orig, 6).T
Y_test = convert_to_one_hot(Y_test_orig, 6).T
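As a quick illustration of what convert_to_one_hot and the final transpose do, here is a toy check (the labels below are made up, not from the dataset):

Y_toy = np.array([[0, 2, 5, 2]])        # same layout as Y_train_orig: shape (1, m)
print(convert_to_one_hot(Y_toy, 6).T)   # shape (4, 6): one row per example
# [[1. 0. 0. 0. 0. 0.]
#  [0. 0. 1. 0. 0. 0.]
#  [0. 0. 0. 0. 0. 1.]
#  [0. 0. 1. 0. 0. 0.]]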

print ("number of training examples = " + str(X_train.shape[0]))
print ("number of test examples = " + str(X_test.shape[0]))
print ("X_train shape: " + str(X_train.shape))
print ("Y_train shape: " + str(Y_train.shape))
print ("X_test shape: " + str(X_test.shape))
print ("Y_test shape: " + str(Y_test.shape))

number of training examples = 1080
number of test examples = 120
X_train shape: (1080, 64, 64, 3)
Y_train shape: (1080, 6)
X_test shape: (120, 64, 64, 3)
Y_test shape: (120, 6)

Train the model (this may take quite a while)

model.fit(X_train, Y_train, epochs = 20, batch_size = 32)
Epoch 1/20
1080/1080 [==========] - 268s 248ms/step - loss: 2.9721 - acc: 0.2898
Epoch 2/20
1080/1080 [==========] - 270s 250ms/step - loss: 1.8968 - acc: 0.3639
Epoch 3/20
1080/1080 [==========] - 268s 248ms/step - loss: 1.5796 - acc: 0.4463
Epoch 4/20
1080/1080 [==========] - 251s 233ms/step - loss: 1.2796 - acc: 0.5213
Epoch 5/20
1080/1080 [==========] - 260s 241ms/step - loss: 0.9278 - acc: 0.6722
Epoch 6/20
1080/1080 [==========] - 261s 242ms/step - loss: 0.7286 - acc: 0.7315
Epoch 7/20
1080/1080 [==========] - 258s 239ms/step - loss: 0.4950 - acc: 0.8324
Epoch 8/20
1080/1080 [==========] - 261s 241ms/step - loss: 0.3646 - acc: 0.8889
Epoch 9/20
1080/1080 [==========] - 258s 238ms/step - loss: 0.3135 - acc: 0.9019
Epoch 10/20
1080/1080 [==========] - 255s 237ms/step - loss: 0.1291 - acc: 0.9639
Epoch 11/20
1080/1080 [==========] - 253s 235ms/step - loss: 0.0814 - acc: 0.9704
Epoch 12/20
1080/1080 [==========] - 260s 240ms/step - loss: 0.0901 - acc: 0.9685
Epoch 13/20
1080/1080 [==========] - 260s 240ms/step - loss: 0.0848 - acc: 0.9694
Epoch 14/20
1080/1080 [==========] - 261s 242ms/step - loss: 0.0740 - acc: 0.9741
Epoch 15/20
1080/1080 [==========] - 258s 239ms/step - loss: 0.0488 - acc: 0.9833
Epoch 16/20
1080/1080 [==========] - 260s 241ms/step - loss: 0.0257 - acc: 0.9981
Epoch 17/20
1080/1080 [==========] - 259s 240ms/step - loss: 0.0029 - acc: 1.0000
Epoch 18/20
1080/1080 [==========] - 260s 241ms/step - loss: 0.0014 - acc: 1.0000
Epoch 19/20
1080/1080 [==========] - 257s 238ms/step - loss: 8.9325e-04 - acc: 1.0000
Epoch 20/20
1080/1080 [==========] - 255s 236ms/step - loss: 6.9667e-04 - acc: 1.0000
<keras.callbacks.History at 0x21761b11710>

Evaluate accuracy on the test set

preds = model.evaluate(X_test, Y_test)
print ("Loss = " + str(preds[0]))
print ("Test Accuracy = " + str(preds[1]))
120/120 [==========] - 6s 49ms/step
Loss = 0.11732131155828635
Test Accuracy = 0.9666666666666667
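Beyond the aggregate metrics you can also sanity-check individual predictions; a small sketch (the index 0 is arbitrary, chosen only for illustration):

probs = model.predict(X_test[0:1])               # keep the batch dimension: input shape (1, 64, 64, 3)
print("predicted class:", np.argmax(probs, axis=1)[0])
print("true class:", np.argmax(Y_test[0]))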

Print the model architecture

model.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================
input_3 (InputLayer) (None, 64, 64, 3) 0


zero_padding2d_3 (ZeroPadding2D (None, 70, 70, 3) 0 input_3[0][0]


conv (Conv2D) (None, 32, 32, 64) 9472 zero_padding2d_3[0][0]


bn_conv1 (BatchNormalization) (None, 32, 32, 64) 256 conv[0][0]


activation_96 (Activation) (None, 32, 32, 64) 0 bn_conv1[0][0]


max_pooling2d_3 (MaxPooling2D) (None, 15, 15, 64) 0 activation_96[0][0]


res2a_branch2a (Conv2D) (None, 15, 15, 64) 4160 max_pooling2d_3[0][0]


bn2a_branch2a (BatchNormalizati (None, 15, 15, 64) 256 res2a_branch2a[0][0]


activation_97 (Activation) (None, 15, 15, 64) 0 bn2a_branch2a[0][0]


res2a_branch2b (Conv2D) (None, 15, 15, 64) 36928 activation_97[0][0]


bn2a_branch2b (BatchNormalizati (None, 15, 15, 64) 256 res2a_branch2b[0][0]


activation_98 (Activation) (None, 15, 15, 64) 0 bn2a_branch2b[0][0]


res2a_branch2c (Conv2D) (None, 15, 15, 256) 16640 activation_98[0][0]


res2a_branch1 (Conv2D) (None, 15, 15, 256) 16640 max_pooling2d_3[0][0]


bn2a_branch2c (BatchNormalizati (None, 15, 15, 256) 1024 res2a_branch2c[0][0]


bn2a_branch1 (BatchNormalizatio (None, 15, 15, 256) 1024 res2a_branch1[0][0]


add_32 (Add) (None, 15, 15, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]


activation_99 (Activation) (None, 15, 15, 256) 0 add_32[0][0]


res2b_branch2a (Conv2D) (None, 15, 15, 64) 16448 activation_99[0][0]


bn2b_branch2a (BatchNormalizati (None, 15, 15, 64) 256 res2b_branch2a[0][0]


activation_100 (Activation) (None, 15, 15, 64) 0 bn2b_branch2a[0][0]


res2b_branch2b (Conv2D) (None, 15, 15, 64) 36928 activation_100[0][0]


bn2b_branch2b (BatchNormalizati (None, 15, 15, 64) 256 res2b_branch2b[0][0]


activation_101 (Activation) (None, 15, 15, 64) 0 bn2b_branch2b[0][0]


res2b_branch2c (Conv2D) (None, 15, 15, 256) 16640 activation_101[0][0]


bn2b_branch2c (BatchNormalizati (None, 15, 15, 256) 1024 res2b_branch2c[0][0]


add_33 (Add) (None, 15, 15, 256) 0 bn2b_branch2c[0][0]
activation_99[0][0]


activation_102 (Activation) (None, 15, 15, 256) 0 add_33[0][0]


res2c_branch2a (Conv2D) (None, 15, 15, 64) 16448 activation_102[0][0]


bn2c_branch2a (BatchNormalizati (None, 15, 15, 64) 256 res2c_branch2a[0][0]


activation_103 (Activation) (None, 15, 15, 64) 0 bn2c_branch2a[0][0]


res2c_branch2b (Conv2D) (None, 15, 15, 64) 36928 activation_103[0][0]


bn2c_branch2b (BatchNormalizati (None, 15, 15, 64) 256 res2c_branch2b[0][0]


activation_104 (Activation) (None, 15, 15, 64) 0 bn2c_branch2b[0][0]


res2c_branch2c (Conv2D) (None, 15, 15, 256) 16640 activation_104[0][0]


bn2c_branch2c (BatchNormalizati (None, 15, 15, 256) 1024 res2c_branch2c[0][0]


add_34 (Add) (None, 15, 15, 256) 0 bn2c_branch2c[0][0]
activation_102[0][0]


activation_105 (Activation) (None, 15, 15, 256) 0 add_34[0][0]


res3a_branch2a (Conv2D) (None, 15, 15, 128) 32896 activation_105[0][0]


bn3a_branch2a (BatchNormalizati (None, 15, 15, 128) 512 res3a_branch2a[0][0]


activation_106 (Activation) (None, 15, 15, 128) 0 bn3a_branch2a[0][0]


res3a_branch2b (Conv2D) (None, 15, 15, 128) 147584 activation_106[0][0]


bn3a_branch2b (BatchNormalizati (None, 15, 15, 128) 512 res3a_branch2b[0][0]


activation_107 (Activation) (None, 15, 15, 128) 0 bn3a_branch2b[0][0]


res3a_branch2c (Conv2D) (None, 15, 15, 512) 66048 activation_107[0][0]


res3a_branch1 (Conv2D) (None, 15, 15, 512) 131584 activation_105[0][0]


bn3a_branch2c (BatchNormalizati (None, 15, 15, 512) 2048 res3a_branch2c[0][0]


bn3a_branch1 (BatchNormalizatio (None, 15, 15, 512) 2048 res3a_branch1[0][0]


add_35 (Add) (None, 15, 15, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]


activation_108 (Activation) (None, 15, 15, 512) 0 add_35[0][0]


res3b_branch2a (Conv2D) (None, 15, 15, 128) 65664 activation_108[0][0]


bn3b_branch2a (BatchNormalizati (None, 15, 15, 128) 512 res3b_branch2a[0][0]


activation_109 (Activation) (None, 15, 15, 128) 0 bn3b_branch2a[0][0]


res3b_branch2b (Conv2D) (None, 15, 15, 128) 147584 activation_109[0][0]


bn3b_branch2b (BatchNormalizati (None, 15, 15, 128) 512 res3b_branch2b[0][0]


activation_110 (Activation) (None, 15, 15, 128) 0 bn3b_branch2b[0][0]


res3b_branch2c (Conv2D) (None, 15, 15, 512) 66048 activation_110[0][0]


bn3b_branch2c (BatchNormalizati (None, 15, 15, 512) 2048 res3b_branch2c[0][0]


add_36 (Add) (None, 15, 15, 512) 0 bn3b_branch2c[0][0]
activation_108[0][0]


activation_111 (Activation) (None, 15, 15, 512) 0 add_36[0][0]


res3c_branch2a (Conv2D) (None, 15, 15, 128) 65664 activation_111[0][0]


bn3c_branch2a (BatchNormalizati (None, 15, 15, 128) 512 res3c_branch2a[0][0]


activation_112 (Activation) (None, 15, 15, 128) 0 bn3c_branch2a[0][0]


res3c_branch2b (Conv2D) (None, 15, 15, 128) 147584 activation_112[0][0]


bn3c_branch2b (BatchNormalizati (None, 15, 15, 128) 512 res3c_branch2b[0][0]


activation_113 (Activation) (None, 15, 15, 128) 0 bn3c_branch2b[0][0]


res3c_branch2c (Conv2D) (None, 15, 15, 512) 66048 activation_113[0][0]


bn3c_branch2c (BatchNormalizati (None, 15, 15, 512) 2048 res3c_branch2c[0][0]


add_37 (Add) (None, 15, 15, 512) 0 bn3c_branch2c[0][0]
activation_111[0][0]


activation_114 (Activation) (None, 15, 15, 512) 0 add_37[0][0]


res3d_branch2a (Conv2D) (None, 15, 15, 128) 65664 activation_114[0][0]


bn3d_branch2a (BatchNormalizati (None, 15, 15, 128) 512 res3d_branch2a[0][0]


activation_115 (Activation) (None, 15, 15, 128) 0 bn3d_branch2a[0][0]


res3d_branch2b (Conv2D) (None, 15, 15, 128) 147584 activation_115[0][0]


bn3d_branch2b (BatchNormalizati (None, 15, 15, 128) 512 res3d_branch2b[0][0]


activation_116 (Activation) (None, 15, 15, 128) 0 bn3d_branch2b[0][0]


res3d_branch2c (Conv2D) (None, 15, 15, 512) 66048 activation_116[0][0]


bn3d_branch2c (BatchNormalizati (None, 15, 15, 512) 2048 res3d_branch2c[0][0]


add_38 (Add) (None, 15, 15, 512) 0 bn3d_branch2c[0][0]
activation_114[0][0]


activation_117 (Activation) (None, 15, 15, 512) 0 add_38[0][0]


res4a_branch2a (Conv2D) (None, 8, 8, 256) 131328 activation_117[0][0]


bn4a_branch2a (BatchNormalizati (None, 8, 8, 256) 1024 res4a_branch2a[0][0]


activation_118 (Activation) (None, 8, 8, 256) 0 bn4a_branch2a[0][0]


res4a_branch2b (Conv2D) (None, 8, 8, 256) 590080 activation_118[0][0]


bn4a_branch2b (BatchNormalizati (None, 8, 8, 256) 1024 res4a_branch2b[0][0]


activation_119 (Activation) (None, 8, 8, 256) 0 bn4a_branch2b[0][0]


res4a_branch2c (Conv2D) (None, 8, 8, 1024) 263168 activation_119[0][0]


res4a_branch1 (Conv2D) (None, 8, 8, 1024) 525312 activation_117[0][0]


bn4a_branch2c (BatchNormalizati (None, 8, 8, 1024) 4096 res4a_branch2c[0][0]


bn4a_branch1 (BatchNormalizatio (None, 8, 8, 1024) 4096 res4a_branch1[0][0]


add_39 (Add) (None, 8, 8, 1024) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]


activation_120 (Activation) (None, 8, 8, 1024) 0 add_39[0][0]


res4b_branch2a (Conv2D) (None, 8, 8, 256) 262400 activation_120[0][0]


bn4b_branch2a (BatchNormalizati (None, 8, 8, 256) 1024 res4b_branch2a[0][0]


activation_121 (Activation) (None, 8, 8, 256) 0 bn4b_branch2a[0][0]


res4b_branch2b (Conv2D) (None, 8, 8, 256) 590080 activation_121[0][0]


bn4b_branch2b (BatchNormalizati (None, 8, 8, 256) 1024 res4b_branch2b[0][0]


activation_122 (Activation) (None, 8, 8, 256) 0 bn4b_branch2b[0][0]


res4b_branch2c (Conv2D) (None, 8, 8, 1024) 263168 activation_122[0][0]


bn4b_branch2c (BatchNormalizati (None, 8, 8, 1024) 4096 res4b_branch2c[0][0]


add_40 (Add) (None, 8, 8, 1024) 0 bn4b_branch2c[0][0]
activation_120[0][0]


activation_123 (Activation) (None, 8, 8, 1024) 0 add_40[0][0]


res4c_branch2a (Conv2D) (None, 8, 8, 256) 262400 activation_123[0][0]


bn4c_branch2a (BatchNormalizati (None, 8, 8, 256) 1024 res4c_branch2a[0][0]


activation_124 (Activation) (None, 8, 8, 256) 0 bn4c_branch2a[0][0]


res4c_branch2b (Conv2D) (None, 8, 8, 256) 590080 activation_124[0][0]


bn4c_branch2b (BatchNormalizati (None, 8, 8, 256) 1024 res4c_branch2b[0][0]


activation_125 (Activation) (None, 8, 8, 256) 0 bn4c_branch2b[0][0]


res4c_branch2c (Conv2D) (None, 8, 8, 1024) 263168 activation_125[0][0]


bn4c_branch2c (BatchNormalizati (None, 8, 8, 1024) 4096 res4c_branch2c[0][0]


add_41 (Add) (None, 8, 8, 1024) 0 bn4c_branch2c[0][0]
activation_123[0][0]


activation_126 (Activation) (None, 8, 8, 1024) 0 add_41[0][0]


res4d_branch2a (Conv2D) (None, 8, 8, 256) 262400 activation_126[0][0]


bn4d_branch2a (BatchNormalizati (None, 8, 8, 256) 1024 res4d_branch2a[0][0]


activation_127 (Activation) (None, 8, 8, 256) 0 bn4d_branch2a[0][0]


res4d_branch2b (Conv2D) (None, 8, 8, 256) 590080 activation_127[0][0]


bn4d_branch2b (BatchNormalizati (None, 8, 8, 256) 1024 res4d_branch2b[0][0]


activation_128 (Activation) (None, 8, 8, 256) 0 bn4d_branch2b[0][0]


res4d_branch2c (Conv2D) (None, 8, 8, 1024) 263168 activation_128[0][0]


bn4d_branch2c (BatchNormalizati (None, 8, 8, 1024) 4096 res4d_branch2c[0][0]


add_42 (Add) (None, 8, 8, 1024) 0 bn4d_branch2c[0][0]
activation_126[0][0]


activation_129 (Activation) (None, 8, 8, 1024) 0 add_42[0][0]


res4e_branch2a (Conv2D) (None, 8, 8, 256) 262400 activation_129[0][0]


bn4e_branch2a (BatchNormalizati (None, 8, 8, 256) 1024 res4e_branch2a[0][0]


activation_130 (Activation) (None, 8, 8, 256) 0 bn4e_branch2a[0][0]


res4e_branch2b (Conv2D) (None, 8, 8, 256) 590080 activation_130[0][0]


bn4e_branch2b (BatchNormalizati (None, 8, 8, 256) 1024 res4e_branch2b[0][0]


activation_131 (Activation) (None, 8, 8, 256) 0 bn4e_branch2b[0][0]


res4e_branch2c (Conv2D) (None, 8, 8, 1024) 263168 activation_131[0][0]


bn4e_branch2c (BatchNormalizati (None, 8, 8, 1024) 4096 res4e_branch2c[0][0]


add_43 (Add) (None, 8, 8, 1024) 0 bn4e_branch2c[0][0]
activation_129[0][0]


activation_132 (Activation) (None, 8, 8, 1024) 0 add_43[0][0]


res5a_branch2a (Conv2D) (None, 4, 4, 512) 524800 activation_132[0][0]


bn5a_branch2a (BatchNormalizati (None, 4, 4, 512) 2048 res5a_branch2a[0][0]


activation_133 (Activation) (None, 4, 4, 512) 0 bn5a_branch2a[0][0]


res5a_branch2b (Conv2D) (None, 4, 4, 512) 2359808 activation_133[0][0]


bn5a_branch2b (BatchNormalizati (None, 4, 4, 512) 2048 res5a_branch2b[0][0]


activation_134 (Activation) (None, 4, 4, 512) 0 bn5a_branch2b[0][0]


res5a_branch2c (Conv2D) (None, 4, 4, 2048) 1050624 activation_134[0][0]


res5a_branch1 (Conv2D) (None, 4, 4, 2048) 2099200 activation_132[0][0]


bn5a_branch2c (BatchNormalizati (None, 4, 4, 2048) 8192 res5a_branch2c[0][0]


bn5a_branch1 (BatchNormalizatio (None, 4, 4, 2048) 8192 res5a_branch1[0][0]


add_44 (Add) (None, 4, 4, 2048) 0 bn5a_branch2c[0][0]
bn5a_branch1[0][0]


activation_135 (Activation) (None, 4, 4, 2048) 0 add_44[0][0]


res5b_branch2a (Conv2D) (None, 4, 4, 512) 1049088 activation_135[0][0]


bn5b_branch2a (BatchNormalizati (None, 4, 4, 512) 2048 res5b_branch2a[0][0]


activation_136 (Activation) (None, 4, 4, 512) 0 bn5b_branch2a[0][0]


res5b_branch2b (Conv2D) (None, 4, 4, 512) 2359808 activation_136[0][0]


bn5b_branch2b (BatchNormalizati (None, 4, 4, 512) 2048 res5b_branch2b[0][0]


activation_137 (Activation) (None, 4, 4, 512) 0 bn5b_branch2b[0][0]


res5b_branch2c (Conv2D) (None, 4, 4, 2048) 1050624 activation_137[0][0]


bn5b_branch2c (BatchNormalizati (None, 4, 4, 2048) 8192 res5b_branch2c[0][0]


add_45 (Add) (None, 4, 4, 2048) 0 bn5b_branch2c[0][0]
activation_135[0][0]


activation_138 (Activation) (None, 4, 4, 2048) 0 add_45[0][0]


res5c_branch2a (Conv2D) (None, 4, 4, 512) 1049088 activation_138[0][0]


bn5c_branch2a (BatchNormalizati (None, 4, 4, 512) 2048 res5c_branch2a[0][0]


activation_139 (Activation) (None, 4, 4, 512) 0 bn5c_branch2a[0][0]


res5c_branch2b (Conv2D) (None, 4, 4, 512) 2359808 activation_139[0][0]


bn5c_branch2b (BatchNormalizati (None, 4, 4, 512) 2048 res5c_branch2b[0][0]


activation_140 (Activation) (None, 4, 4, 512) 0 bn5c_branch2b[0][0]


res5c_branch2c (Conv2D) (None, 4, 4, 2048) 1050624 activation_140[0][0]


bn5c_branch2c (BatchNormalizati (None, 4, 4, 2048) 8192 res5c_branch2c[0][0]


add_46 (Add) (None, 4, 4, 2048) 0 bn5c_branch2c[0][0]
activation_138[0][0]


activation_141 (Activation) (None, 4, 4, 2048) 0 add_46[0][0]


average_pooling2d_3 (AveragePoo (None, 2, 2, 2048) 0 activation_141[0][0]


flatten_3 (Flatten) (None, 8192) 0 average_pooling2d_3[0][0]


fc6 (Dense) (None, 6) 49158 flatten_3[0][0]
==================================
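The parameter counts in the table can be verified by hand: a Conv2D layer has kernel_height × kernel_width × input_channels × filters weights plus one bias per filter, and each BatchNormalization layer stores four values (gamma, beta, moving mean, moving variance) per channel. For example:

print(7*7*3*64 + 64)   # 9472, the "conv" (conv1) row: 64 filters of shape (7,7) over 3 channels
print(4*64)            # 256, the "bn_conv1" row
print(1*1*64*64 + 64)  # 4160, the "res2a_branch2a" row: 1x1 convolution, 64 -> 64 channels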
