前言
需求来自声明saver时需要指定变量列表。
var_list = saver._var_list
print(type(var_list[0]))
#output : <class 'tensorflow.python.ops.variables.RefVariable'>
print(var_list[0])
#output : <tf.Variable 'g_b1:0' shape=(32768,) dtype=float32_ref>
也就是说,我们也需要凑成这种类型的var_list才能输入给saver。
直接用str类型的name组成一个list是无效的!!!
1.动态获取
1.1两种最朴素的方法
朴实而无华,好记又好用。
#朴素获取可训练变量
t_vars = tf.trainable_variables()
#朴素获取全部变量,包含声明training=False的变量
all_vars = tf.global_variables()
2.2使用tensorflow.contrib.slim
import tensorflow.contrib.slim as slim
#下面这行代码返回常规变量
#常规变量是slim里面与model变量对应的一个类型
regular_variables = slim.get_variables()
#你也可以直接
vars = slim.get_variables_to_restore()
slim是一个非常强大的工具,有兴趣可以深入了解。
#slim还支持各种筛选方法
#通过name
variables = slim.get_variables_by_name("d_")
#通过name后缀
variables = slim.get_variables_by_suffix("_b")
#通过namespace
variables = slim.get_variables(scope="layer1")
#通过include和exclude筛选名字
variables_to_restore = slim.get_variables_to_restore(include=["d_"])
variables_to_restore = slim.get_variables_to_restore(exclude=["_w"])
//参考https://blog.csdn.net/guvcolie/article/details/77686555
二.离线获取
从一个已保存好的model中获取var_list。
2.1 将离线文件载入当前环境,于是问题变回了“动态获取”
#记住,要先清空现有的图
#不然的话import_meta_graph会把原model里面的数据追加到现有的model中
#一片混乱
tf.reset_default_graph()
with tf.Session(graph=tf.get_default_graph()) as sess:
new_saver = tf.train.import_meta_graph('e:/mytrain/results/20190227_01/model/model.meta')
new_saver.restore(sess, 'e:/mytrain/results/20190227_01/model/model')
#加载进来之后还不是为所欲为
var_list=tf.global_variables()
print(var_list)
2.2直接对离线文件进行读取
使用内置的checkpointReader
但是注意,这种方法获取到的var_list与要求的格式不一样。
我还不知道怎么转换成saver要求的类型,等一个有缘人相告!
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
#文件夹地址改成自己的
model_dir="'e:\\20190227_01\\mytrain\\results\\20190227_01\\model"
ckpt = tf.train.get_checkpoint_state(model_dir)
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
#返回一个dict= {'name':[shape] }
#例如 'd_w2/Adam':[4, 4, 32, 64]
var_to_shape_map = reader.get_variable_to_shape_map()
#我们可以用遍历的方式,取出字典里所有的key
for key in var_to_shape_map:
print(key) #key是str类型的
#再用key去找这个tensor的值
a=reader.get_tensor(key)
print(type(a)) #输出: <class 'numpy.ndarray'>
//参考https://blog.csdn.net/wc781708249/article/details/78040735
三、外部保存
思路是把获得的list输出到一个外部txt文件中
以后使用只需要读取txt即可
PS:目前这种方法只支持str类型保存变量name!不支持saver!
代码示例:
with open("var_list.txt","w") as f:
f.write( ','.join(str(var_list))) #write只能打印str类型,需要强制转换
效果如图:
要使用的时候,只需要按如下代码读取
with open("var_list.txt","r") as f:
str1 = f.readlines()[0]
var_list=str1.split(',')
print(var_list)