Tensorflow 获取model中的变量列表

前言

 

需求来自声明saver时需要指定变量列表。

var_list = saver._var_list
print(type(var_list[0]))

#output : <class 'tensorflow.python.ops.variables.RefVariable'>

print(var_list[0])
#output : <tf.Variable 'g_b1:0' shape=(32768,) dtype=float32_ref>

也就是说,我们也需要凑成这种类型的var_list才能输入给saver。

直接用str类型的name组成一个list是无效的!!!

 

1.动态获取

 

1.1两种最朴素的方法

朴实而无华,好记又好用。

#朴素获取可训练变量
t_vars = tf.trainable_variables()


#朴素获取全部变量,包含声明training=False的变量
all_vars = tf.global_variables()

 

2.2使用tensorflow.contrib.slim

import tensorflow.contrib.slim as slim

#下面这行代码返回常规变量
#常规变量是slim里面与model变量对应的一个类型
regular_variables = slim.get_variables()


#你也可以直接
vars = slim.get_variables_to_restore()

slim是一个非常强大的工具,有兴趣可以深入了解。

#slim还支持各种筛选方法

#通过name
variables = slim.get_variables_by_name("d_")

#通过name后缀
variables = slim.get_variables_by_suffix("_b")

#通过namespace
variables = slim.get_variables(scope="layer1")


#通过include和exclude筛选名字
variables_to_restore = slim.get_variables_to_restore(include=["d_"])

variables_to_restore = slim.get_variables_to_restore(exclude=["_w"])

 

//参考https://blog.csdn.net/guvcolie/article/details/77686555

 

二.离线获取

从一个已保存好的model中获取var_list。

 

2.1 将离线文件载入当前环境,于是问题变回了“动态获取”

#记住,要先清空现有的图
#不然的话import_meta_graph会把原model里面的数据追加到现有的model中
#一片混乱

tf.reset_default_graph()

with tf.Session(graph=tf.get_default_graph()) as sess:
    new_saver = tf.train.import_meta_graph('e:/mytrain/results/20190227_01/model/model.meta')
    new_saver.restore(sess, 'e:/mytrain/results/20190227_01/model/model')
    #加载进来之后还不是为所欲为
    var_list=tf.global_variables()
    print(var_list)

 

2.2直接对离线文件进行读取

使用内置的checkpointReader

但是注意,这种方法获取到的var_list与要求的格式不一样。

 我还不知道怎么转换成saver要求的类型,等一个有缘人相告!

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

#文件夹地址改成自己的
model_dir="'e:\\20190227_01\\mytrain\\results\\20190227_01\\model"

ckpt = tf.train.get_checkpoint_state(model_dir)
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)

#返回一个dict= {'name':[shape] }
#例如 'd_w2/Adam':[4, 4, 32, 64]
var_to_shape_map = reader.get_variable_to_shape_map()

#我们可以用遍历的方式,取出字典里所有的key
for key in var_to_shape_map:
    print(key)        #key是str类型的
    #再用key去找这个tensor的值
    a=reader.get_tensor(key)
    print(type(a))    #输出: <class 'numpy.ndarray'>
    

//参考https://blog.csdn.net/wc781708249/article/details/78040735

 

 

三、外部保存

思路是把获得的list输出到一个外部txt文件中

以后使用只需要读取txt即可

PS:目前这种方法只支持str类型保存变量name!不支持saver!

 

代码示例:

with open("var_list.txt","w") as f:
        f.write( ','.join(str(var_list)))   #write只能打印str类型,需要强制转换 

效果如图:

 

要使用的时候,只需要按如下代码读取

with open("var_list.txt","r") as f:
        str1 = f.readlines()[0]
        var_list=str1.split(',')
        print(var_list)

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值