TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow

model=VGG16()#此处构建vgg16模型
my_scope='my'#外加的空间名
variables = tf.global_variables(my_scope)#获取模型中my_scope变量空间下的所有变量

file_name='vgg16.ckpt'#vgg16网络模型
# 1
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名

# 2
#model_variables=tf.train.load_checkpoint(file_name) 
#var_to_shape_map=model_variables.get_variable_to_shape_map()

print(var_to_shape_map)

sess=tf.Session()


variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
    if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
       print('Variables restored: %s' % v.name)
       variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
    elif v.name.split(':')[0] in var_to_shape_map:
       print('Variables restored: %s' % v.name)
       variables_to_restore[v.name]=v

restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

 

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,

<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,
 

vgg16.ckpt的fc6,fc7权重shape如下:

 'vgg_16/fc6/weights': [7, 7, 512, 4096],

 'vgg_16/fc7/weights': [1, 1, 4096, 4096],
 

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
                
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
                              "vgg_16/fc7/weights": fc7_conv,
                             })
restorer_fc.restore(sess, pretrained_model)

sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))   
                                                        
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值