Tensorflow学习(持续更新)

26 篇文章 1 订阅
20 篇文章 0 订阅

1、Tensorflow读取并输出已保存模型的权重数值

(1)输出已保存模型的权重数值

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

#首先,使用tensorflow自带的python打包库读取模型
#此处的model.ckpt是forzon_inference_graph中的三个文件:
#model.ckpt.data-00000-of-00001、model.ckpt.index、model.ckpt.meta
model_reader = pywrap_tensorflow.NewCheckpointReader(r"model.ckpt")

#然后,使reader变换成类似于dict形式的数据
var_dict = model_reader.get_variable_to_shape_map()

#最后,循环打印输出
for key in var_dict:
    print("variable name: ", key)
    print(model_reader.get_tensor(key))

结果:

(2)输出已保存模型的权重名和shape

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

#首先,使用tensorflow自带的python打包库读取模型
model_reader = pywrap_tensorflow.NewCheckpointReader(r"model.ckpt")

#然后,使reader变换成类似于dict形式的数据
var_dict = model_reader.get_variable_to_shape_map()

#最后,循环打印输出
for key in var_dict:
    print("variable name: ", key)
    print(model_reader.get_tensor(key).shape)
print(len(var_dict))

结果:

2、打印tensorflow每一层结构

import tensorflow as tf

sess=tf.Session 
with tf.Graph().as_default(): 
    #pb文件与本代码文件位于一个文件夹下
    with tf.gfile.FastGFile('frozen_inference_graph.pb','rb') as modelfile: 
        graph_def=tf.GraphDef() 
        graph_def.ParseFromString(modelfile.read()) 
        tf.import_graph_def(graph_def) 
        [print(n.name) for n in tf.get_default_graph().as_graph_def().node]

结果:

3、keras模型导出成tf模型

github地址:https://github.com/amir-abdi/keras_to_tensorflow

命令行输入(h5和pb文件名自行设定):python keras_to_tensorflow.py --input_model="yolo.h5" --output_model="yolo.pb"

结果:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值