[Tensorflow] 模型存储、查看与载入

[Tensorflow] 模型存储、查看与载入

版本: tensorflow-1.8.0
代码:Github

1. 模型存储

使用tf.train.Saver模块, 保存路径的URL名称一定要*.ckpt。

import tensorflow as tf

v1 = tf.get_variable("v1", shape=[1], initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0), dtype=tf.float32)
v2 = tf.get_variable("v2", shape=[1], initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0), dtype=tf.float32)
result = v1 + v2

init_op = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "/home/jagger/workspace/tmp/model.ckpt")

模型会保存会在“/home/jagger/workspace/tmp/“目录下出现三个文件

这里写图片描述

其中.meta保存的是模型图结构,.index保存的是模型参数的索引,*.data保存的是模型参数具体数值。

如果要保存不同训练次数的模型,可以这样

    saver.save(sess, "/home/jagger/workspace/tmp/model.ckpt", global_step=step)

这时候就会自动在.ckpt加上-次数,例如:

这里写图片描述

如果想只保存模型结构(Graph)

saver = tf.train.Saver()
saver.export_meta_graph("/home/jagger/workspace/tmp/model.ckpt.meta", as_text=True)  # 可用编辑器打开,为json格式

2. 参数查看

从.ckpt文件中查看模型参数值

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# print all variables in checkpoint file
print("=========All Variables==========")
print_tensors_in_checkpoint_file("/home/jagger/workspace/tmp/model.ckpt", tensor_name=None, all_tensors=True, all_tensor_names=True)

# print only tensor v1 in checkpoint file
print("=========V1==========")
print_tensors_in_checkpoint_file("/home/jagger/workspace/tmp/model.ckpt", tensor_name='v1', all_tensors=False, all_tensor_names=False)

# print only tensor v2 in checkpoint file
print("=========V2==========")
print_tensors_in_checkpoint_file("/home/jagger/workspace/tmp/model.ckpt", tensor_name='v2', all_tensors=False, all_tensor_names=False)

输出:

=========All Variables==========
tensor_name:  v1
[-0.46169969]
tensor_name:  v2
[ 0.40403476]
=========V1==========
tensor_name:  v1
[-0.46169969]
=========V2==========
tensor_name:  v2
[ 0.40403476]

如果模型参数文件是.ckpt-1000.index和.ckpt-1000.data形式,则输入的url也要加上-次数,例如

print_tensors_in_checkpoint_file("/home/jagger/workspace/tmp/model.ckpt-1000",
                                     tensor_name=None, all_tensors=True, all_tensor_names=True)

3. 模型参数载入

只载入模型参数,模型结构需要自己构建:

import tensorflow as tf

v1 = tf.get_variable("other-v1", shape=[1], initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0), dtype=tf.float32)
v2 = tf.get_variable("other-v2", shape=[1], initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0), dtype=tf.float32)
result = v1 + v2

saver = tf.train.Saver({"v1": v1, "v2": v2})  # 指定对应的参数,缺省则按参数名载入

with tf.Session() as sess:

    saver.restore(sess, "/home/jagger/workspace/tmp/model.ckpt")
    print(sess.run(result))

既载入模型结构(Graph),也载入参数值

import tensorflow as tf

# importing graph
saver = tf.train.import_meta_graph("/home/jagger/workspace/tmp/model.ckpt.meta")

with tf.Session() as sess:

    # loading variable value to sess
    saver.restore(sess, "/home/jagger/workspace/tmp/model.ckpt")
    result = tf.get_default_graph().get_tensor_by_name("add:0")
    print(sess.run(result))
  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值