Building on the previous post https://blog.csdn.net/q199502092010/article/details/89381472, and with reference to these two articles:
https://www.jianshu.com/p/b9e38c1c94b1
https://www.jianshu.com/p/4b5ff96e70b5
this post builds a Faster R-CNN network with a VGG backbone on top of the TensorFlow Object Detection API.
Building a new network only requires swapping in a new feature extractor. The previous post already registered the new feature extractor in the config; this post focuses on modifying the feature extraction network itself. The VGG network definition lives under research/slim/nets. For basic use the file does not actually need to be changed, but in practice the layers before pool5 are what serve as the feature extractor, so the unneeded layers can be commented out.
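The "keep the layers before pool5" idea can be sketched without TensorFlow. The layer names below follow the usual VGG-16 naming and the cut point is illustrative; in the real slim file you would comment out everything from pool5 onward rather than slicing a list:

```python
# VGG-16 as an ordered list of layer names (illustrative sketch).
VGG16_LAYERS = [
    'conv1', 'pool1',
    'conv2', 'pool2',
    'conv3', 'pool3',
    'conv4', 'pool4',
    'conv5', 'pool5',
    'fc6', 'fc7', 'fc8',
]

def feature_extractor_layers(layers, cut='pool5'):
    """Return the layers before `cut`; the rest can be commented out."""
    return layers[:layers.index(cut)]

# The first-stage feature extractor keeps everything up to conv5.
print(feature_extractor_layers(VGG16_LAYERS))
```

The last retained layer is conv5, so the RPN operates on the conv5 feature map while pool5 and the fully connected layers are dropped.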
def vgg_a(inputs,
          num_classes=1000,
          is_training=True,
          dropout_keep_prob=0.5,
          spatial_squeeze=True,
          scope='vgg_a',
          fc_conv_padding='VALID',
          global_pool=False):
  """Oxford Net VGG 11-Layers version A Example.

  Note: All the fully_connected layers have been transformed to conv2d layers.
        To use in classification mode, resize input to 224x224.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: number of predicted classes. If 0 or None, the logits layer is
      omitted and the input features to the logits layer are returned instead.
    is_training: whether or not the model is being trained.
    dropout_keep_prob: the probability that activations are kept in the dropout
      layers during training.
    spatial_squeeze: whether or not should squeeze the spatial dimensions of the
      outputs. Useful to remove unnecessary dimensions for classification.
    scope: Optional scope for the variables.
    fc_conv_padding: the type of padding to use for the fully connected layer
      that is implemented as a convolutional layer. Use 'SAME' padding if you
      are applying the network in a fully convolutional manner and want to
      get a prediction map downsampled by a factor of 32 as an output.
      Otherwise, the output prediction map will be (input / 32) - 6 in case of
      'VALID' padding.
    global_pool: Optional boolean flag. If True, the input to the classification
      layer is avgpooled to size 1x1, for any input size. (This is not part
      of the original VGG architecture.)

  Returns:
    net: the output of the logits layer (if num_classes is a non-zero integer),
      or the input to the logits layer (if num_classes is 0 or None).
    end_points: a dict of tensors with intermediate activations.
  """
  with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc:
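The docstring's claim that the output map is (input / 32) - 6 under 'VALID' padding follows from simple arithmetic: the five 2x2 max-pools divide the spatial size by 32, and fc6, implemented as a 7x7 convolution with no padding, removes another 6. A minimal sketch of this bookkeeping (the function name is mine, not from the slim library):

```python
def vgg_output_map_size(input_size, fc_conv_padding='VALID'):
    """Spatial size of the vgg prediction map for a square input.

    Five 2x2 max-pools divide the input by 32; fc6 is implemented as a
    7x7 convolution, which under 'VALID' padding removes another 6.
    """
    size = input_size // 32          # after the five pooling layers
    if fc_conv_padding == 'VALID':
        size -= 6                    # 7x7 conv, stride 1, no padding
    return size

print(vgg_output_map_size(224))          # 1  -> classification mode
print(vgg_output_map_size(512, 'SAME'))  # 16 -> fully convolutional
```

This is why classification mode asks for a 224x224 input: 224 / 32 - 6 = 1, a single prediction per image, while 'SAME' padding keeps the map at exactly input / 32 for fully convolutional use.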