TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()。
然而,当只需要选择性地加载部分模型参数时,则需要利用pywrap_tensorflow读取ckpt文件,分析其中保存的变量名及其shape。
例子:Faster-RCNN中,模型加载vgg16.ckpt时,需要利用pywrap_tensorflow读取ckpt文件中的参数:
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# Selectively restore the variables of a VGG16 graph built under an extra
# variable scope from a checkpoint whose variable names lack that scope.
model = VGG16()  # build the VGG16 graph here
my_scope = 'my'  # extra variable scope wrapped around the model
# Global variables selected by the my_scope scope filter.
variables = tf.global_variables(my_scope)
file_name = 'vgg16.ckpt'  # pretrained VGG16 checkpoint

# Option 1: read the checkpoint through pywrap_tensorflow.
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
# Maps every variable name stored in the checkpoint to its shape.
var_to_shape_map = reader.get_variable_to_shape_map()
# Option 2: alternative way of obtaining the same map.
# model_variables = tf.train.load_checkpoint(file_name)
# var_to_shape_map = model_variables.get_variable_to_shape_map()
print(var_to_shape_map)

sess = tf.Session()
scope_prefix = my_scope + '/'
# Maps checkpoint variable name -> graph variable that should receive it.
variables_to_restore = {}
for v in variables:
    graph_name = v.name.split(':')[0]  # drop the ':0' output suffix
    # Prefer the name with the extra scope (and its '/') stripped, then fall
    # back to the unmodified graph name.
    candidates = [graph_name]
    if graph_name.startswith(scope_prefix):
        candidates.insert(0, graph_name[len(scope_prefix):])
    ckpt_name = next((c for c in candidates if c in var_to_shape_map), None)
    if ckpt_name is None:
        continue
    # A Saver cannot restore into a variable of a different shape (e.g. fc
    # weights stored in convolutional layout), so leave those out here.
    if v.get_shape().as_list() != list(var_to_shape_map[ckpt_name]):
        print('Variables skipped (shape mismatch): %s' % v.name)
        continue
    print('Variables restored: %s' % v.name)
    variables_to_restore[ckpt_name] = v

# Only the variables collected above are read from the checkpoint.
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, file_name)
实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:
<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,
vgg16.ckpt的fc6,fc7权重shape如下:
'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],
因此,不能直接按变量名恢复fc6和fc7:需要先把ckpt中卷积形式的权重恢复到形状相同的临时变量中,再reshape成全连接层权重的形状(7*7*512=25088)后赋值,操作如下:
# Restore the fc6/fc7 weights into temporary 4-D variables, then copy them,
# reshaped, into the corresponding graph variables.
# NOTE(review): `sess`, `pretrained_model` and `self._variables_to_fix` are not
# defined in this snippet -- presumably it runs inside a method of the network
# class; confirm against the surrounding code.
fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
# Checkpoint variable name -> temporary variable that receives it.
conv_by_ckpt_name = {
    "vgg_16/fc6/weights": fc6_conv,
    "vgg_16/fc7/weights": fc7_conv,
}
restorer_fc = tf.train.Saver(conv_by_ckpt_name)
restorer_fc.restore(sess, pretrained_model)
for target_name, conv_weights in (
        ('my/vgg_16/fc6/weights:0', fc6_conv),
        ('my/vgg_16/fc7/weights:0', fc7_conv),
):
    target = self._variables_to_fix[target_name]
    # Reshape to the target variable's own static shape before assigning.
    sess.run(tf.assign(target, tf.reshape(conv_weights, target.get_shape())))