5. /lib/networks/VGGnet_train.py
This file defines the complete Faster R-CNN network: from the VGG backbone to the RPN, then RoI proposal, and finally the R-CNN head.
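The stages above are presumably wired together by the self.setup() call in the constructor, as a chain of named layers. The following is a minimal, hypothetical sketch of that chaining pattern. It assumes network.py works the way the feed()/conv() names suggest: feed() selects input layers by name, and each layer method stores its output in self.layers under name and returns self so calls can be chained. TinyNetwork, layer, scale and add are illustrative names, not the real network.py API, and plain Python lists stand in for tensors.

```python
import functools


def layer(op):
    """Hypothetical decorator: read the inputs chosen by feed(), run op,
    store the result under `name`, feed it onward, and return self."""
    @functools.wraps(op)
    def wrapped(self, *args, name, **kwargs):
        # a single input is passed directly; several inputs are passed as a list
        inputs = self.terminals[0] if len(self.terminals) == 1 else list(self.terminals)
        output = op(self, inputs, *args, **kwargs)
        self.layers[name] = output
        self.feed(name)  # the new output becomes the next layer's input
        return self
    return wrapped


class TinyNetwork:
    """Illustrative stand-in for a Network base class (not the real one)."""

    def __init__(self, inputs):
        self.layers = dict(inputs)  # name -> output, like self.layers in VGGnet_train
        self.terminals = []

    def feed(self, *names):
        self.terminals = [self.layers[n] for n in names]
        return self

    @layer
    def scale(self, x, factor):
        return [v * factor for v in x]

    @layer
    def add(self, xs):
        return [a + b for a, b in zip(*xs)]


net = TinyNetwork({'data': [1, 2, 3]})
net.feed('data').scale(2, name='branch_a')
net.feed('data').scale(10, name='branch_b')
net.feed('branch_a', 'branch_b').add(name='merged')
print(net.layers['merged'])  # [12, 24, 36]
```

Returning self from every layer is what allows a whole branch (for example VGG conv layers followed by the RPN head) to be written as one fluent expression, while feed() with several names lets a layer such as the RoI stage consume outputs from different branches.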
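The network structure is as follows:
Related functions:
- The functions this file calls, such as feed(), conv() and softmax(), are all defined in network.py.
- The 'rpn-data' layer uses: anchor_target_layer
- The 'rpn-rois' layer uses: proposal_layer
- The 'roi-data' layer uses: proposal_target_layer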
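Both anchor_target_layer and proposal_layer start from the same grid of anchor boxes, built from the _feat_stride and anchor_scales constants defined in the code below. The following is a sketch modeled on the generate_anchors routine from py-faster-rcnn, not code from this repository. The 16-pixel base size and the aspect ratios (0.5, 1, 2) are assumptions that do not appear in this file, and base_anchors / shifted_anchors are hypothetical names.

```python
import numpy as np


def base_anchors(base_size=16, ratios=(0.5, 1.0, 2.0), scales=(8, 16, 32)):
    """Return len(ratios) * len(scales) anchors as (x1, y1, x2, y2),
    all centred on the first feature-map cell."""
    ctr = (base_size - 1) / 2.0
    anchors = []
    for r in ratios:
        # keep the area close to base_size**2 while setting the aspect ratio h / w = r
        w = np.round(np.sqrt(base_size * base_size / r))
        h = np.round(w * r)
        for s in scales:
            ws, hs = w * s, h * s
            anchors.append([ctr - 0.5 * (ws - 1), ctr - 0.5 * (hs - 1),
                            ctr + 0.5 * (ws - 1), ctr + 0.5 * (hs - 1)])
    return np.array(anchors)


def shifted_anchors(feat_h, feat_w, feat_stride=16, **kwargs):
    """Replicate the base anchors at every feature-map cell: (H * W * A, 4)."""
    base = base_anchors(**kwargs)                              # (A, 4)
    sx, sy = np.meshgrid(np.arange(feat_w) * feat_stride,
                         np.arange(feat_h) * feat_stride)      # each (H, W)
    shifts = np.stack([sx.ravel(), sy.ravel(),
                       sx.ravel(), sy.ravel()], axis=1)        # (K, 4), x varies fastest
    return (base[None, :, :] + shifts[:, None, :]).reshape(-1, 4)


print(base_anchors()[0])  # [-84. -40.  99.  55.]
```

With three scales and three ratios there are 9 anchors per cell, so a conv5_3 map of size H x W yields 9·H·W candidate boxes. anchor_target_layer labels these against gt_boxes for RPN training, while proposal_layer applies the predicted deltas to them and keeps the top-scoring boxes after NMS.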
Code walkthrough:
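import tensorflow as tf
from networks.network import Network

# Module-level constants
n_classes = 21               # 20 PASCAL VOC classes + background
_feat_stride = [16,]         # total stride of VGG16 conv5_3 relative to the input image
anchor_scales = [8, 16, 32]  # anchor scales, relative to the 16-pixel base anchor

class VGGnet_train(Network):
    def __init__(self, trainable=True):
        self.inputs = []
        # input image batch: [batch, height, width, 3]
        self.data = tf.placeholder(tf.float32, shape=[None, None, None, 3])
        # per-image info: [height, width, im_scale]
        self.im_info = tf.placeholder(tf.float32, shape=[None, 3])
        # ground-truth boxes: [x1, y1, x2, y2, class]
        self.gt_boxes = tf.placeholder(tf.float32, shape=[None, 5])
        # dropout keep probability (the fraction of units kept, not dropped)
        self.keep_prob = tf.placeholder(tf.float32)
        self.layers = dict({
            'data': self.data, 'im_info': self.im_info, 'gt_boxes': self.gt_boxes})
        self.trainable = trainable
        self.setup()

        # create ops and placeholders for bbox normalization process
        # Fetch the existing bbox_pred weights and biases (created in setup()) and define
        # placeholders of the same shape, so new values can be written back with tf.assign
        with tf.variable_scope('bbox_pred', reuse=True):
            weights = tf.get_variable("weights")
            biases = tf.get_variable("biases")

            self.bbox_weights = tf.placeholder(weights.dtype, shape=weights.get_shape())
            self.bbox_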