tensorflow模型加载与保存的两种方式

当我们训练模型时希望保存模型以便继续训练或者发布,总之,模型加载与保存是经常用到的。

1.第一种加载与保存方法

1.1 保存

import tensorflow as tf
import numpy as np
#定义图
...
istraing = tf.placeholder(tf.bool,name='istraing')
...
with tf.name_scope('loss'):
	total_loss = loss1 + loss2 + loss3
#其他操作等
...
# 定义Saver用于保存模型,如果 tf.train.Saver()的()不写东西,则默认保存图中所有,填东西就保存填进去的
saver = tf.train.Saver()
...
#开启会话
with tf.Session() as sess:
	#训练等
	...
	# 保存模型
    saver.save(sess,'models/my_model.ckpt')

这种保存需要这两句代码实现(saver = tf.train.Saver()和saver.save(sess,‘models/my_model.ckpt’)),而且模型保存的后缀一般定义为.ckpt(习惯但可以不遵守),最后在相应文件夹会产生.ckpt.data-XXX(保存模型参数)、.ckpt.meta-XXX(保存模型结构)、.ckpt.index.XXX(据说跟.ckpt.data-XXX一起保存模型参数)
最后这种模型保存方式的特点是:加载.ckpt文件等后可以继续训练

1.2 加载1

import tensorflow as tf
import numpy as np
#其他操作等
...
#开启会话
with tf.Session() as sess:
    # 载入模型结构
    saver = tf.train.import_meta_graph('models/my_model.ckpt.meta')
    # 载入模型参数
    saver.restore(sess,'models/my_model.ckpt')
    total_loss = sess.graph.get_tensor_by_name('loss/total_loss/add_1:0')
    #继续训练
    ...

使用saver = tf.train.import_meta_graph(‘models/my_model.ckpt.meta’)和saver.restore(sess,‘models/my_model.ckpt’)这两句即可加载模型与参数,后面调用sess.graph.get_tensor_by_name与sess.graph.get_operation_by_name并根据保存模型前定义的name得到模型中的tensor与operation,然后再sess.run()它们可以继续推断、训练与转化为第二种保存方式(.pb文件)

1.2 加载2

有可能你从网上得到只有.ckpt.data-XXX没有.ckpt.meta-XXX,这意味着没有模型结构,此时你在加载1的基础上需要先定义好对应的模型,然后saver.restore(sess,‘models/my_model.ckpt’)载入模型参数,之后也可以继续推断、训练与转化为第二种保存方式(.pb文件)。

2.第二种加载与保存方法

2.1 保存

import tensorflow as tf
import numpy as np
#定义图
...
istraing = tf.placeholder(tf.bool,name='istraing')
...
with tf.name_scope('loss'):
	total_loss = loss1 + loss2 + loss3
#其他操作等
...

#开启会话
with tf.Session() as sess:
	#训练等
	...
    # 保存模型参数和结构,把变量变成常量
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output','accuracy'])
    # 保存模型到目录下的models文件夹中
    with tf.gfile.FastGFile('pb_models/my_model.pb',mode='wb') as f:
        f.write(output_graph_def.SerializeToString())

这种保存方式模型是保存为.pb文件的,之后加载只能推断不能继续训练,但文件大小比前一种保存方式得到的文件小。

2.2 加载

import tensorflow as tf
import numpy as np

#载入模型
with tf.gfile.FastGFile('pb_models/my_model.pb', 'rb') as f:
    # 创建一个图
    graph_def = tf.GraphDef()
    # 把模型文件载入到图中
    graph_def.ParseFromString(f.read())
    # 载入图到当前环境中
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
    # 根据tensor的名字获取到对应的tensor
    # 之前保存模型的时候模型输出保存为output,":0"是保存模型参数时自动加上的,所以这里也要写上
    output = sess.graph.get_tensor_by_name('output:0')
    # 根据tensor的名字获取到对应的tensor
    # 之前保存模型的时候准确率计算保存为accuracy,":0"是保存模型参数时自动加上的,所以这里也要写上
    accuracy = sess.graph.get_tensor_by_name('accuracy:0')
    # 预测准确率
    print(sess.run(accuracy,feed_dict={'x-input:0':mnist.test.images,'y-input:0':mnist.test.labels}))

.pb模型的加载方式是固定的,一般直接使用即可。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值