源代码链接:
https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_utils.py
https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py
1、TensorFlow中resnet_v1_50的用法
首先总结一下用法,源码中resnet_v1_50的参数如下:
def resnet_v1_50(inputs,
num_classes=None,
is_training=True,
global_pool=True,
output_stride=None,
spatial_squeeze=True,
store_non_strided_activations=False,
min_base_depth=8,
depth_multiplier=1,
reuse=None,
scope='resnet_v1_50'):
其中:
- input:训练集,其格式为[batch, height_in, width_in, channels]
- num_classes:样本的种类别数,用于定义出上层的节点个数。如果为“None”的话,其最终输出的应该是[batch,1,1,2048],若“spatial_stride=True”,则其最终输出为[batch,2048]
- is_training:是否在训练模型中加入“Batch_Norm”层
- global_pool:该层位于整个网络结构之后,位于“num_classes”之前。为“True”则表示对于网络最后一个“net”层的输出结果做一个全局的average pooling。所谓全局池化就是池化的stride等于输入的size,得到一个标量。
- spatial_squeeze:将列表中维度等于1的维度去掉,如spatial_squeeze([B,1,1,C])