def R_Net(inputs,label=None,bbox_target=None,landmark_target=None,training=True):
    """R-Net (refinement stage) of the MTCNN face detector.

    Args:
        inputs: batch of candidate face crops; the traced shapes in the
            debug prints assume (batch, 24, 24, 3) — confirm with the caller.
        label: per-sample class labels (used only when training=True).
        bbox_target: bounding-box regression targets (training only).
        landmark_target: facial-landmark regression targets (training only).
        training: if True, compute and return the OHEM losses; otherwise
            return the raw network predictions.

    Returns:
        training=True:  (cls_loss, bbox_loss, landmark_loss, L2_loss, accuracy)
        training=False: (cls_prob, bbox_pred, landmark_pred)
    """
    # Every conv layer shares the same setup: PReLU activation, Xavier
    # weight init, zero bias init, L2 weight decay 5e-4, 'valid' padding.
    with slim.arg_scope([slim.conv2d],
                        activation_fn=prelu,
                        weights_initializer=slim.xavier_initializer(),
                        biases_initializer=tf.zeros_initializer(),
                        weights_regularizer=slim.l2_regularizer(0.0005),
                        padding='valid'):
        print (inputs.get_shape())  # expect (384, 24, 24, 3)

        # Backbone: three conv layers interleaved with two max-pools.
        x = slim.conv2d(inputs, num_outputs=28, kernel_size=[3, 3], stride=1, scope="conv1")
        print(x.get_shape())  # (384, 22, 22, 28)
        x = slim.max_pool2d(x, kernel_size=[3, 3], stride=2, scope="pool1", padding='SAME')
        print(x.get_shape())  # (384, 11, 11, 28)
        x = slim.conv2d(x, num_outputs=48, kernel_size=[3, 3], stride=1, scope="conv2")
        print(x.get_shape())  # (384, 9, 9, 48)
        x = slim.max_pool2d(x, kernel_size=[3, 3], stride=2, scope="pool2")
        print(x.get_shape())  # (384, 4, 4, 48)
        x = slim.conv2d(x, num_outputs=64, kernel_size=[2, 2], stride=1, scope="conv3")
        print(x.get_shape())  # (384, 3, 3, 64)

        # Flatten feature maps and project to a 128-d embedding.
        flattened = slim.flatten(x)
        print(flattened.get_shape())  # (384, 576)
        embedding = slim.fully_connected(flattened, num_outputs=128, scope="fc1")
        print(embedding.get_shape())  # (384, 128)

        # Three prediction heads off the shared embedding.
        # Face / non-face probability: batch x 2, softmax-normalized.
        cls_prob = slim.fully_connected(embedding, num_outputs=2, scope="cls_fc", activation_fn=tf.nn.softmax)
        print(cls_prob.get_shape())  # (384, 2)
        # Bounding-box regression offsets: batch x 4, linear output.
        bbox_pred = slim.fully_connected(embedding, num_outputs=4, scope="bbox_fc", activation_fn=None)
        print(bbox_pred.get_shape())  # (384, 4)
        # Five facial landmarks (x, y pairs): batch x 10, linear output.
        landmark_pred = slim.fully_connected(embedding, num_outputs=10, scope="landmark_fc", activation_fn=None)
        print(landmark_pred.get_shape())  # (384, 10)

        if not training:
            # Inference: hand back the raw predictions.
            return cls_prob, bbox_pred, landmark_pred

        # Training: OHEM losses per head plus the accumulated L2
        # regularization penalty from the arg_scope regularizer.
        cls_loss = cls_ohem(cls_prob, label)
        bbox_loss = bbox_ohem(bbox_pred, bbox_target, label)
        accuracy = cal_accuracy(cls_prob, label)
        landmark_loss = landmark_ohem(landmark_pred, landmark_target, label)
        L2_loss = tf.add_n(slim.losses.get_regularization_losses())
        return cls_loss, bbox_loss, landmark_loss, L2_loss, accuracy