直接上代码探讨:
def batch_norm_layer(x, train_phase, scope_bn, decay=0.5, epsilon=1e-3):
    """Batch-normalize ``x`` over every axis except the last one.

    In the training phase the statistics of the current batch are used and,
    as a side effect, folded into an exponential moving average. Otherwise
    the moving averages are used in place of the batch statistics.

    Args:
        x: input tensor; the last axis is normalized per element (one
            ``beta``/``gamma`` entry per position of that axis). Its rank
            must be statically known because ``len(x.shape)`` is evaluated
            at graph-construction time.
        train_phase: boolean predicate handed to ``tf.cond`` (the caller in
            this file passes a ``tf.Variable(True/False)``).
        scope_bn: name of the variable scope that holds this layer's
            variables.
        decay: decay of the ``tf.train.ExponentialMovingAverage`` tracking
            the batch mean and variance. Defaults to the previously
            hard-coded 0.5.
        epsilon: variance epsilon passed to ``tf.nn.batch_normalization``.
            Defaults to the previously hard-coded 1e-3.

    Returns:
        The normalized tensor produced by ``tf.nn.batch_normalization``.
    """
    with tf.variable_scope(scope_bn):
        depth = x.shape[-1]
        # Learnable offset (beta, starts at 0) and scale (gamma, starts at 1).
        beta = tf.Variable(tf.constant(0.0, shape=[depth]), name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[depth]), name='gamma', trainable=True)

        # Reduce over all leading axes, keeping only the last one.
        reduce_axes = list(range(len(x.shape) - 1))
        batch_mean, batch_var = tf.nn.moments(x, reduce_axes, name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            # The identity ops are created under the control dependency so
            # that reading the batch statistics also runs the EMA update.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        def mean_var_from_average():
            # NOTE(review): ema.average() only returns a variable after
            # ema.apply() has registered the tensor; this presumably relies
            # on tf.cond building the training branch first -- confirm.
            return ema.average(batch_mean), ema.average(batch_var)

        mean, var = tf.cond(train_phase, mean_var_with_update, mean_var_from_average)
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
    return normed
def all_cnn_bn(inputs,
               num_classes=10,
               is_training=True,
               dropout_keep_prob=0.5,
               spatial_squeeze=True,
               scope='all_cnn',
               fc_conv_padding='VALID',
               is_extracting=None):
    """Build an all-convolutional classifier with batch norm after each conv.

    Every "fully connected" layer is expressed as a ``slim.conv2d`` layer, and
    each convolution is followed by ``batch_norm_layer``. The stack ends in a
    6x6 average pool.

    Note: to use in classification mode, resize input to 32x32.

    Args:
        inputs: a tensor of size [batch_size, height, width, channels].
        num_classes: number of predicted classes (output depth of 'fc7').
        is_training: whether or not the model is being trained. Controls the
            dropout layers and the initial value of the boolean variable
            handed to every ``batch_norm_layer`` call.
        dropout_keep_prob: keep probability for 'dropout_2' and 'dropout_3'
            ('dropout_1' on the raw inputs always uses 0.9).
        spatial_squeeze: whether to squeeze axes 1 and 2 of the pooled output.
            This requires both of those axes to have size 1.
        scope: optional scope for the variables.
        fc_conv_padding: padding passed to 'conv5', 'fc6', 'fc7' and
            'avg_pool'.
        is_extracting: when truthy, return the feature dict instead of the
            (net, end_points) pair.

    Returns:
        If ``is_extracting`` is truthy, a dict with 'f7' (the flattened,
        batch-normalized 'fc7' output) and 'f8' (the final tensor).
        Otherwise a tuple ``(net, end_points)`` where ``net`` is the final
        tensor and ``end_points`` maps names to the layer outputs gathered
        through the outputs collection, plus an '<scope>/fc8' entry when
        ``spatial_squeeze`` is set.
    """
    with tf.variable_scope(scope, 'all_cnn', [inputs]) as sc:
        end_points_collection = sc.name + '_end_points'
        # Collect outputs for conv2d, fully_connected and max_pool2d.
        collected_ops = [slim.conv2d, slim.fully_connected, slim.max_pool2d]
        with slim.arg_scope(collected_ops, outputs_collections=end_points_collection):
            # One shared, non-trainable flag selects the batch-norm branch.
            phase_var = tf.Variable(bool(is_training), trainable=False, name="cur_train_phase")

            def conv_bn(tensor, depth, kernel, conv_scope, bn_scope, **conv_kwargs):
                # A slim.conv2d layer followed by batch_norm_layer.
                convolved = slim.conv2d(tensor, depth, kernel, scope=conv_scope, **conv_kwargs)
                return batch_norm_layer(x=convolved, train_phase=phase_var, scope_bn=bn_scope)

            inputs = slim.dropout(inputs, 0.9, is_training=is_training, scope='dropout_1')

            # Block 1: 96 channels, finishing with a stride-2 convolution.
            net = conv_bn(inputs, 96, [3, 3], 'conv1_1', 'bn_1', stride=1)
            net = conv_bn(net, 96, [3, 3], 'conv1_2', 'bn_2', stride=1)
            net = conv_bn(net, 96, [3, 3], 'conv2', 'bn_3', stride=2)
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout_2')

            # Block 2: 192 channels, finishing with a stride-2 convolution.
            net = conv_bn(net, 192, [3, 3], 'conv3_1', 'bn_4', stride=1)
            net = conv_bn(net, 192, [3, 3], 'conv3_2', 'bn_5', stride=1)
            net = conv_bn(net, 192, [3, 3], 'conv4', 'bn_6', stride=2)
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout_3')

            # Head: these layers use the caller-selected padding.
            extract_f = {}
            net = conv_bn(net, 192, [3, 3], 'conv5', 'bn_7', padding=fc_conv_padding)
            net = conv_bn(net, 192, [1, 1], 'fc6', 'bn_8', padding=fc_conv_padding)
            net = conv_bn(net, num_classes, [1, 1], 'fc7', 'bn_9', padding=fc_conv_padding)
            extract_f['f7'] = slim.flatten(net)

            net = slim.avg_pool2d(net, [6, 6], scope='avg_pool', padding=fc_conv_padding)
            print(net.get_shape(),'avg_pool')

            end_points = slim.utils.convert_collection_to_dict(end_points_collection)
            if spatial_squeeze:
                net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
                end_points[sc.name + '/fc8'] = net
            print("Net has been establed.", net.get_shape())
            extract_f['f8'] = net

            if is_extracting:
                return extract_f
            return net, end_points