6. /lib/networks/network.py
本文件是/lib/networks/VGGnet_train.py的支持文件:定义了网络基类Network和layer装饰器,并提供构建各层所需的函数。
各层调用的函数(均从rpn_msr模块导入,导入时加了_py后缀):
- ‘rpn-data’层所用到的函数:anchor_target_layer
- ‘rpn-rois’层所用到的函数:proposal_layer
- ‘roi-data’层所用到的函数:proposal_target_layer
代码解读:
import numpy as np
import tensorflow as tf
import roi_pooling_layer.roi_pooling_op as roi_pool_op
import roi_pooling_layer.roi_pooling_op_grad
from rpn_msr.proposal_layer_tf import proposal_layer as proposal_layer_py
from rpn_msr.anchor_target_layer_tf import anchor_target_layer as anchor_target_layer_py
from rpn_msr.proposal_target_layer_tf import proposal_target_layer as proposal_target_layer_py
# Module-level default padding mode. Presumably passed as the TensorFlow
# `padding` argument by layer methods defined later in this file (not visible
# here) -- TODO confirm against those methods.
DEFAULT_PADDING = 'SAME'
def layer(op):
    """Wrap a Network layer method so that layers can be chained.

    The wrapper takes the currently fed ``self.inputs`` as the layer input,
    calls ``op(self, layer_input, *args, **kwargs)``, stores the result in
    ``self.layers`` under the layer's name, feeds the result as the input of
    the next call and returns ``self`` so calls can be chained.

    Args:
        op: Layer method written as ``op(self, input, *args, **kwargs)``. It
            receives the layer name as the ``name`` keyword argument.

    Returns:
        The wrapped method. It keeps ``op``'s name and docstring via
        ``functools.wraps``.
    """
    # Function-scope import keeps the module's import block untouched.
    import functools

    @functools.wraps(op)
    def layer_decorated(self, *args, **kwargs):
        # Automatically set a name if not provided. op.__name__ is the layer
        # method's name (e.g. 'conv', 'max_pool'); per the original notes,
        # get_unique_name turns it into something like 'conv_4'. The name is
        # left in kwargs, so op receives it as a keyword argument.
        name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
        # Figure out the layer input from what has been fed so far.
        if not self.inputs:
            raise RuntimeError('No input variables found for layer %s.' % name)
        elif len(self.inputs) == 1:
            # A single fed input is passed through as-is.
            layer_input = self.inputs[0]
        else:
            # Several fed inputs are passed as a (new) list.
            layer_input = list(self.inputs)
        # The real work happens here, not where op is defined. This is why op
        # definitions take one more positional argument (the input) than
        # callers pass.
        layer_output = op(self, layer_input, *args, **kwargs)
        # Record the output in the layer lookup table.
        self.layers[name] = layer_output
        # This output becomes the input of the next chained layer.
        self.feed(layer_output)
        # Return self for chained calls.
        return self

    return layer_decorated
class Network(object):
def __init__(self, inputs, trainable=True):
self.inputs = []
self.layers = dict(inputs)
self.trainable = trainable
self.setup()
def setup(self):
raise NotImplementedError('Must be subclassed.')
def load(self, data_path, session, saver, ignore_missing=False):
    """Restore weights from a '.ckpt' path or from a numpy-saved weight dict.

    Args:
        data_path: A path ending in '.ckpt' is restored with
            ``saver.restore``. Any other path is read with ``np.load`` and is
            expected to hold a dict ``{scope_name: {var_name: value}}``.
        session: Session passed to ``saver.restore`` / used to run the
            assignment ops.
        saver: Object with ``restore(session, path)``; only used for '.ckpt'.
        ignore_missing: If True, a ValueError raised while looking up or
            assigning a variable is logged and skipped instead of re-raised.

    Raises:
        ValueError: from variable lookup/assignment when ``ignore_missing``
            is False.
    """
    if data_path.endswith('.ckpt'):
        saver.restore(session, data_path)
        return
    # The weight dict is stored as a pickled 0-d object array, so it must be
    # loaded with allow_pickle=True (numpy >= 1.16.3 defaults it to False and
    # refuses object arrays otherwise).
    # SECURITY: unpickling can execute arbitrary code -- only load weight files
    # from trusted sources.
    data_dict = np.load(data_path, allow_pickle=True).item()
    for key in data_dict:
        # reuse=True: look up variables that already exist under scope `key`.
        with tf.variable_scope(key, reuse=True):
            for subkey, value in data_dict[key].items():
                try:
                    var = tf.get_variable(subkey)
                    session.run(var.assign(value))
                    # Single parenthesized argument: same output on Py2 and Py3.
                    print("assign pretrain model " + subkey + " to " + key)
                except ValueError:
                    print("ignore " + key)
                    if not ignore_missing:
                        raise