编写于2019/11/22
该方法适用于tensorflow中保存为ckpt文件的模型,最好是保存了所有参数和图的模型
import numpy as np
import tensorflow as tf
import scipy.io as sio
import os
with tf.Session() as sess:
# load the meta graph and weights
saver = tf.train.import_meta_graph('./logs/enet-clip1/model.ckpt-21390.meta')#键入模型名称.meta文件
saver.restore(sess, tf.train.latest_checkpoint('./logs/enet-clip1/'))#从checkpoints中恢复最新模型
graph = tf.get_default_graph()#得到模型的默认图
#按名字获得某一个参数并保存
# conv1_w = sess.run(graph.get_tensor_by_name('Conv1/W_conv1/W_conv1:0'))
# sio.savemat("./net/cnn_for_mnist/weights/conv1_w.mat", {"array": conv1_w})
#从中挑选模型参数,我所需的参数是所有的trainable_variables,以及非训练参数中的moving_mean,moving_variance,根据名称索引
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
variable_names =var_list + bn_moving_vars
#打印出所需参数的name,shape,info,value
values = sess.run(variable_names)
cnt = 0
for k,v in zip(variable_names, values):
print("Variable: ", k)
print("Shape: ", v.shape)
path = ('./logs/enet-clip1/parameters/'+k.name).rstrip(':0')
if not os.path.exists(path):
os.makedirs(path)
sio.savemat(path + '/{:s}.mat'.format(k.name.rstrip(':0').replace('/','-')), {'array': v})#保存需要的参数
else:
print('path existd!',path)
if 'ENet/bottleneck1_0_batch_norm1' in k.name:#选择性打印我所需的参数
print(v)
cnt+=1
print('cnt:',cnt)