tensorflow模型保存和加载

  • 方式一ckpt 使用saver:如果要看,请直接看这个方法的最后两个,前面讲的是官网的坑。

用tf.train.Saver()创建一个saver,然后保存。

# 建立网络创建变量,建图
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 建立saver对象,后面保存的时候要用
saver = tf.train.Saver()


with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())#初始化网络中的变量
  ..
  # 将模型保存起来
  save_path = saver.save(sess, "/tmp/model.ckpt")

模型的加载:

#定义一个saver
saver = tf.train.Saver()

with tf.Session() as sess:
  # 用saver加载
  saver.restore(sess, "/tmp/model.ckpt")

所以,tf.train.Saver真的很简单,定义一下,然后保存就.save,加载就.restore。看起来真美好啊,官网总是这么忽悠人。

事实上,光定义一个saver就会出错,提示没有保存的变量。我的tf 是1.14

# 建立saver对象,后面保存的时候要用
saver = tf.train.Saver()

既然说了缺少保存的变量,那就加如待保存的变量

保存:注意,在 tf.train.Saver的参数里面加了tf.global_variables()这样可以获取所有的全局变量

# 建立网络创建变量,建图
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 建立saver对象,后面保存的时候要用
saver = tf.train.Saver(tf.global_variables())


with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())#初始化网络中的变量
  ..
  # 将模型保存起来
  save_path = saver.save(sess, "/tmp/model.ckpt")

加载:

ckpt_model_path = 'ckpt_model'#checkpoint保存的目录
ckpt = tf.train.get_checkpoint_state(ckpt_model_path)#加载checkpoint中保存的模型的文件名
#定义一个saver
#不再是saver = tf.train.Saver()了而是换一种定义方式,从保存的文件中加载
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path+'.meta')

with tf.Session() as sess:
  # 用saver加载
  saver.restore(sess, ckpt.model_checkpoint_path)

所以,是需要在定义saver的时候指定要保存的变量,在加载saver时,使用tf.train.import_meta_graph(ckpt.model_checkpoint_path+'.meta')

.meta中存的是图的结构,这样也就解决了初始化saver没有变量的问题。所以该方法是加载了图的结构以及各变量名称和数据。

也可以不加载图结构,可以在新图中改变batch size等变量的的值。但是需要在定义saver之前定义好新图的结构。

综上,用saver很方便,定义,然后保存的时候调saver的save,加载时调saver的restore。唯一麻烦的是saver在定义的时候,需要指定变量。保存的时候,指定的变量可以是tf.global_variables(),加载的时候,定义的变量就需要从checkpoint中加载

saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path+'.meta')。

  • 方法二:pb模型

pb模型是谷歌推荐的保存方式,pb模型语言独立、大小也比ckpt少很多。

pb模型的保存和加载方式有两种:

第一种:tf.graph_util

第二种:tf.saved_model

  • 2.1 tf.graph_util

保存:比较简单,就两步:

1.将图中的变量用常量代替

constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output_node_names'])

2.写入文件

tf.io.write_graph(constant_graph,pb_model_path,'model.pb',as_text=False)

如果不嫌麻烦,也可以

with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())

载入:比较复杂,但也就这么几行

读入保存的图文件,定义一个图,从读入的数据中解析出来图

#从保存的文件中读出数据
with open(pb_model_path, 'rb') as f:
    data = f.read()
#定义一个GraphDef对象
graph_def = tf.GraphDef()#这个对象中保存了图的节点信息,使用protocol buffer格式保存
graph_def.ParseFromString(data)
tf.import_graph_def(graph_def,name='')
  • 2.2 tf.saved_model

保存:这个过程比较复杂,要先定义一个builder,然后定义要保存的图的输入,基本上是placeholder,然后定义要保存的图的输出,将这些东西加入到signature中,把signature放到builder上,最后保存。其实还比较好理解了,定义builder、图的输入、输出,signature(把输入输出放在一起),把signature放到builder上,最后保存。

builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
inputs = {'input_x':tf.saved_model.build_tensor_info(model.input_x),
          'input_y':tf.saved_model.build_tensor_info(model.input_y),
          'keep_prob':tf.saved_model.build_tensor_info(model.keep_prob)}#定义图的输入,就是把网络中的placeholder放在这里,给各个placeholder建立信息。
outputs = {'output':tf.saved_model.build_tensor_info(model.predict_y)}#定义图的输出
signature = signature_def_utils.build_signature_def(inputs = inputs, outputs=outputs,method_name = signature_constants.PREDICT_METHOD_NAME)#把输入输出拼在一起。signature_def是把变量名和变量对应起来的操作,method name是指定classifier、prdict还是regress
legacy_init_op = tf.group(tf.tables_initializer(),name='legacy_init_op')#图中有部分操作是属于tf的tables的,比如lookup,参考https://stackoverflow.com/questions/54540018/what-does-tensorflows-tables-initializer-do
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING],signature_def_map = {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:signature},legacy_init_op = legacy_init_op)
builder.save()

载入:比较简单就一句话tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.SERVING],saved_model_dir),载入后取出需要用的变量名,根据变量名取出变量,就可以做inference了

#载入pb模型
meta_graph_def = tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.SERVING],saved_model_dir)
#开始解析模型,signature def中保存了变量名和变量之间的对应关系
signature = meta_graph_def.signature_def
#获取所需要的变量名
input_x_name = signature[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['input_x'].name
keep_prob_name = signature[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['keep_prob'].name
pred_y_name = signature[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs['output'].name
#根据变量名获取图中的变量
input_x = sess.graph.get_tensor_by_name(input_x_name)
keep_prb = sess.graph.get_tensor_by_name(keep_prob_name)
pred_y = sess.graph.get_tensor_by_name(pred_y_name)
#解析完毕,开始运行
with sess.graph.as_default():
    sess.run(tf.global_variables_initilaizer())
    pred = sess.run(pred_y, feed_dict = {input_x:data})

补充:上述步骤还是比较啰嗦的。tensorflow推出了简洁款的。可参考https://blog.csdn.net/weixin_33912638/article/details/89140815

就目前我的经验而言,更推荐2.2虽然看着罗嗦很多但更加可靠。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值