tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测

import tensorflow as tf
  from tensorflow.python.frameworkimport graph_util
   
  filedir = "./models/model10.ckpt"
  output_graph = "./models/frozen_model.pb"
  saver = tf.train.import_meta_graph("./models/model10.ckpt.meta",clear_devices=True)
   
  graph = tf.get_default_graph()
  input_graph_def = graph.as_graph_def()
   
  with tf.Session() as sess:
  saver.restore(sess, filedir)
  for x in tf.global_variables():
  print x
  output_graph_def = graph_util.convert_variables_to_constants(
  sess,
  input_graph_def,
  ["M","rnn/lstm_cell/w_f_diag"])
  test = tf.get_default_graph().get_tensor_by_name("rnn/lstm_cell/biases/read:0").eval()
  print test
   
  with tf.gfile.GFile(output_graph,"wb")as f:
  f.write(output_graph_def.SerializeToString())


1. 训练模型

import tensorflow as tf

sess = tf.Session()
matrix_1 = tf.constant([3., 3.], name='input')
add = tf.add(matrix_1, matrix_1, name='output')
sess.run(add)

output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
# 保存模型到目录下的model文件夹中
with tf.gfile.FastGFile('./model/tensorflow_matrix_graph.pb',mode='wb') as f:
    f.write(output_graph_def.SerializeToString())

sess.close()  
 
 
# coding=utf8
 import tensorflow as tf
 from tensorflow.python.framework import graph_util
 # 1. pb文件的保存方法。
 v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
 v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
 result = v1 + v2
  
 init_op = tf.global_variables_initializer()
 with tf.Session() as sess:
 sess.run(init_op)
 graph_def = tf.get_default_graph().as_graph_def()
 output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
 with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
 f.write(output_graph_def.SerializeToString())
 # 2. 加载pb文件。
 from tensorflow.python.platform import gfile
  
 with tf.Session() as sess:
 model_filename = "Saved_model/combined_model.pb"
  
 with gfile.FastGFile(model_filename, 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
  
 result = tf.import_graph_def(graph_def, return_elements=["add:0"])
 print sess.run(result)
 

唯一注意的一点是务必要保成pb格式的文件:

  1. 不能使用 tf.train.write_graph()保存模型,该种方式只是保存了模型的结构,并不保存训练完毕的参数值
  2. 不能使用 tf.train.saver()保存模型,该种方式只是保存了网络中的参数值,并不保存模型的结构。

我们需要的是既保存模型的结构,又保存模型中每个参数的值,所以上述的两种方式都不行:因此我们用一下方式保存:

# 可以把整个sesion当作常量都保存下来,通过output_node_names参数来指定输出
graph_util.convert_variables_to_constants
# 指定保存文件的路径以及读写方式
tf.gfile.FastGFile('model/test.pb', mode='wb')
# 将固化的模型写入到文件
f.write(output_graph_def.SerializeToString())  



ML主要分为训练和预测两个阶段,此教程就是将训练好的模型freeze并保存下来.freeze的含义就是将该模型的图结构和该模型的权重固化到一起了.也即加载freeze的模型之后,立刻能够使用了。

下面使用一个简单的demo来详细解释该过程,

一、首先运行脚本tiny_model.py

[python]  view plain  copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. import numpy as np  
  4.   
  5.   
  6. with tf.variable_scope('Placeholder'):  
  7.     inputs_placeholder = tf.placeholder(tf.float32, name='inputs_placeholder', shape=[None10])  
  8.     labels_placeholder = tf.placeholder(tf.float32, name='labels_placeholder', shape=[None1])  
  9.   
  10. with tf.variable_scope('NN'):  
  11.     W1 = tf.get_variable('W1', shape=[101], initializer=tf.random_normal_initializer(stddev=1e-1))  
  12.     b1 = tf.get_variable('b1', shape=[1], initializer=tf.constant_initializer(0.1))  
  13.     W2 = tf.get_variable('W2', shape=[101], initializer=tf.random_normal_initializer(stddev=1e-1))  
  14.     b2 = tf.get_variable('b2', shape=[1], initializer=tf.constant_initializer(0.1))  
  15.   
  16.     a = tf.nn.relu(tf.matmul(inputs_placeholder, W1) + b1)  
  17.     a2 = tf.nn.relu(tf.matmul(inputs_placeholder, W2) + b2)  
  18.   
  19.     y = tf.div(tf.add(a, a2), 2)  
  20.   
  21. with tf.variable_scope('Loss'):  
  22.     loss = tf.reduce_sum(tf.square(y - labels_placeholder) / 2)  
  23.   
  24. with tf.variable_scope('Accuracy'):  
  25.     predictions = tf.greater(y, 0.5, name="predictions")  
  26.     correct_predictions = tf.equal(predictions, tf.cast(labels_placeholder, tf.bool), name="correct_predictions")  
  27.     accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))  
  28.   
  29.   
  30. adam = tf.train.AdamOptimizer(learning_rate=1e-3)  
  31. train_op = adam.minimize(loss)  
  32.   
  33. # generate_data  
  34. inputs = np.random.choice(10, size=[1000010])  
  35. labels = (np.sum(inputs, axis=1) > 45).reshape(-11).astype(np.float32)  
  36. print('inputs.shape:', inputs.shape)  
  37. print('labels.shape:', labels.shape)  
  38.   
  39.   
  40. test_inputs = np.random.choice(10, size=[10010])  
  41. test_labels = (np.sum(test_inputs, axis=1) > 45).reshape(-11).astype(np.float32)  
  42. print('test_inputs.shape:', test_inputs.shape)  
  43. print('test_labels.shape:', test_labels.shape)  
  44.   
  45. batch_size = 32  
  46. epochs = 10  
  47.   
  48. batches = []  
  49. print("%d items in batch of %d gives us %d full batches and %d batches of %d items" % (  
  50.     len(inputs),  
  51.     batch_size,  
  52.     len(inputs) // batch_size,  
  53.     batch_size - len(inputs) // batch_size,  
  54.     len(inputs) - (len(inputs) // batch_size) * 32)  
  55. )  
  56. for i in range(len(inputs) // batch_size):  
  57.     batch = [ inputs[batch_size*i:batch_size*i+batch_size], labels[batch_size*i:batch_size*i+batch_size] ]  
  58.     batches.append(list(batch))  
  59. if (i + 1) * batch_size < len(inputs):  
  60.     batch = [ inputs[batch_size*(i + 1):],labels[batch_size*(i + 1):] ]  
  61.     batches.append(list(batch))  
  62. print("Number of batches: %d" % len(batches))  
  63. print("Size of full batch: %d" % len(batches[0]))  
  64. print("Size if final batch: %d" % len(batches[-1]))  
  65.   
  66. global_count = 0  
  67.   
  68. with tf.Session() as sess:  
  69. #sv = tf.train.Supervisor()  
  70. #with sv.managed_session() as sess:  
  71.     sess.run(tf.initialize_all_variables())  
  72.     for i in range(epochs):  
  73.         for batch in batches:  
  74.             # print(batch[0].shape, batch[1].shape)  
  75.             train_loss , _= sess.run([loss, train_op], feed_dict={  
  76.                 inputs_placeholder: batch[0],  
  77.                 labels_placeholder: batch[1]  
  78.             })  
  79.             # print('train_loss: %d' % train_loss)  
  80.   
  81.             if global_count % 100 == 0:  
  82.                 acc = sess.run(accuracy, feed_dict={  
  83.                     inputs_placeholder: test_inputs,  
  84.                     labels_placeholder: test_labels  
  85.                 })  
  86.                 print('accuracy: %f' % acc)  
  87.             global_count += 1  
  88.   
  89.     acc = sess.run(accuracy, feed_dict={  
  90.         inputs_placeholder: test_inputs,  
  91.         labels_placeholder: test_labels  
  92.     })  
  93.     print("final accuracy: %f" % acc)  
  94.     #在session当中就要将模型进行保存  
  95.     saver = tf.train.Saver()  
  96.     last_chkp = saver.save(sess, 'results/graph.chkp')  
  97.     #sv.saver.save(sess, 'results/graph.chkp')  
  98.   
  99. for op in tf.get_default_graph().get_operations():  
  100.     print(op.name)  
说明:saver.save必须在session里面,因为在session里面,整个图才是激活的,才能够将参数存进来,使用save之后能够得到如下的文件:


说明:
.data:存放的是权重参数
.meta:存放的是图和metadata,metadata是其他配置的数据
如果想将我们的模型固化,让别人能够使用,我们仅仅需要的是图和参数,metadata是不需要的

二、综合上述几个文件,生成可以使用的模型的步骤如下

1、恢复我们保存的图
2、开启一个Session,然后载入该图要求的权重
3、删除对预测无关的metadata
4、将处理好的模型序列化之后保存
运行freeze.py

[python]  view plain  copy
  1. #-*- coding:utf-8 -*-  
  2. import os, argparse  
  3. import tensorflow as tf  
  4. from tensorflow.python.framework import graph_util  
  5.   
  6. dir = os.path.dirname(os.path.realpath(__file__))  
  7.   
  8. def freeze_graph(model_folder):  
  9.     # We retrieve our checkpoint fullpath  
  10.     checkpoint = tf.train.get_checkpoint_state(model_folder)  
  11.     input_checkpoint = checkpoint.model_checkpoint_path  
  12.       
  13.     # We precise the file fullname of our freezed graph  
  14.     absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])  
  15.     output_graph = absolute_model_folder + "/frozen_model.pb"  
  16.   
  17.     # Before exporting our graph, we need to precise what is our output node  
  18.     # this variables is plural, because you can have multiple output nodes  
  19.     #freeze之前必须明确哪个是输出结点,也就是我们要得到推论结果的结点  
  20.     #输出结点可以看我们模型的定义  
  21.     #只有定义了输出结点,freeze才会把得到输出结点所必要的结点都保存下来,或者哪些结点可以丢弃  
  22.     #所以,output_node_names必须根据不同的网络进行修改  
  23.     output_node_names = "Accuracy/predictions"  
  24.   
  25.     # We clear the devices, to allow TensorFlow to control on the loading where it wants operations to be calculated  
  26.     clear_devices = True  
  27.       
  28.     # We import the meta graph and retrive a Saver  
  29.     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)  
  30.   
  31.     # We retrieve the protobuf graph definition  
  32.     graph = tf.get_default_graph()  
  33.     input_graph_def = graph.as_graph_def()  
  34.   
  35.     #We start a session and restore the graph weights  
  36.     #这边已经将训练好的参数加载进来,也即最后保存的模型是有图,并且图里面已经有参数了,所以才叫做是frozen  
  37.     #相当于将参数已经固化在了图当中   
  38.     with tf.Session() as sess:  
  39.         saver.restore(sess, input_checkpoint)  
  40.   
  41.         # We use a built-in TF helper to export variables to constant  
  42.         output_graph_def = graph_util.convert_variables_to_constants(  
  43.             sess,   
  44.             input_graph_def,   
  45.             output_node_names.split(","# We split on comma for convenience  
  46.         )   
  47.   
  48.         # Finally we serialize and dump the output graph to the filesystem  
  49.         with tf.gfile.GFile(output_graph, "wb") as f:  
  50.             f.write(output_graph_def.SerializeToString())  
  51.         print("%d ops in the final graph." % len(output_graph_def.node))  
  52.   
  53.   
  54. if __name__ == '__main__':  
  55.     parser = argparse.ArgumentParser()  
  56.     parser.add_argument("--model_folder", type=str, help="Model folder to export")  
  57.     args = parser.parse_args()  
  58.   
  59.     freeze_graph(args.model_folder)  

说明: 对于freeze操作,我们需要定义输出结点的名字.因为网络其实是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉.因为我们freeze模型的目的是接下来做预测.所以,一般情况下,output_node_names就是我们预测的目标 .

三、加载freeze后的模型,注意该模型已经是包含图和相应的参数了.所以,我们不需要再加载参数进来.也即该模型加载进来已经是可以使用了.

[python]  view plain  copy
  1. #-*- coding:utf-8 -*-  
  2. import argparse   
  3. import tensorflow as tf  
  4.   
  5. def load_graph(frozen_graph_filename):  
  6.     # We parse the graph_def file  
  7.     with tf.gfile.GFile(frozen_graph_filename, "rb") as f:  
  8.         graph_def = tf.GraphDef()  
  9.         graph_def.ParseFromString(f.read())  
  10.   
  11.     # We load the graph_def in the default graph  
  12.     with tf.Graph().as_default() as graph:  
  13.         tf.import_graph_def(  
  14.             graph_def,   
  15.             input_map=None,   
  16.             return_elements=None,   
  17.             name="prefix",   
  18.             op_dict=None,   
  19.             producer_op_list=None  
  20.         )  
  21.     return graph  
  22.   
  23. if __name__ == '__main__':  
  24.     parser = argparse.ArgumentParser()  
  25.     parser.add_argument("--frozen_model_filename", default="results/frozen_model.pb", type=str, help="Frozen model file to import")  
  26.     args = parser.parse_args()  
  27.     #加载已经将参数固化后的图  
  28.     graph = load_graph(args.frozen_model_filename)  
  29.   
  30.     # We can list operations  
  31.     #op.values() gives you a list of tensors it produces  
  32.     #op.name gives you the name  
  33.     #输入,输出结点也是operation,所以,我们可以得到operation的名字  
  34.     for op in graph.get_operations():  
  35.         print(op.name,op.values())  
  36.         # prefix/Placeholder/inputs_placeholder  
  37.         # ...  
  38.         # prefix/Accuracy/predictions  
  39.     #操作有:prefix/Placeholder/inputs_placeholder  
  40.     #操作有:prefix/Accuracy/predictions  
  41.     #为了预测,我们需要找到我们需要feed的tensor,那么就需要该tensor的名字  
  42.     #注意prefix/Placeholder/inputs_placeholder仅仅是操作的名字,prefix/Placeholder/inputs_placeholder:0才是tensor的名字  
  43.     x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')  
  44.     y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')  
  45.           
  46.     with tf.Session(graph=graph) as sess:  
  47.         y_out = sess.run(y, feed_dict={  
  48.             x: [[3574511111]] # < 45  
  49.         })  
  50.         print(y_out) # [[ 0.]] Yay!  
  51.     print ("finish")  
说明:

1、在预测的过程中,当把freeze后的模型加载进来后,我们只需要定义好输入的tensor和目标tensor即可

2、在这里要注意一下tensor_name和ops_name,

注意prefix/Placeholder/inputs_placeholder仅仅是操作的名字,prefix/Placeholder/inputs_placeholder:0才是tensor的名字

x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')一定要使用tensor的名字

3、要获取图中ops的名字和对应的tensor的名字,可用如下的代码:

[python]  view plain  copy
  1. # We can list operations  
  2. #op.values() gives you a list of tensors it produces  
  3. #op.name gives you the name  
  4. #输入,输出结点也是operation,所以,我们可以得到operation的名字  
  5. for op in graph.get_operations():  
  6.     print(op.name,op.values())  

=============================================================================================================================

上面是使用了Saver()来保存模型,也可以使用sv = tf.train.Supervisor()来保存模型

[python]  view plain  copy
  1. #-*- coding:utf-8 -*-  
  2. import tensorflow as tf  
  3. import numpy as np  
  4.   
  5.   
  6. with tf.variable_scope('Placeholder'):  
  7.     inputs_placeholder = tf.placeholder(tf.float32, name='inputs_placeholder', shape=[None10])  
  8.     labels_placeholder = tf.placeholder(tf.float32, name='labels_placeholder', shape=[None1])  
  9.   
  10. with tf.variable_scope('NN'):  
  11.     W1 = tf.get_variable('W1', shape=[101], initializer=tf.random_normal_initializer(stddev=1e-1))  
  12.     b1 = tf.get_variable('b1', shape=[1], initializer=tf.constant_initializer(0.1))  
  13.     W2 = tf.get_variable('W2', shape=[101], initializer=tf.random_normal_initializer(stddev=1e-1))  
  14.     b2 = tf.get_variable('b2', shape=[1], initializer=tf.constant_initializer(0.1))  
  15.   
  16.     a = tf.nn.relu(tf.matmul(inputs_placeholder, W1) + b1)  
  17.     a2 = tf.nn.relu(tf.matmul(inputs_placeholder, W2) + b2)  
  18.   
  19.     y = tf.div(tf.add(a, a2), 2)  
  20.   
  21. with tf.variable_scope('Loss'):  
  22.     loss = tf.reduce_sum(tf.square(y - labels_placeholder) / 2)  
  23.   
  24. with tf.variable_scope('Accuracy'):  
  25.     predictions = tf.greater(y, 0.5, name="predictions")  
  26.     correct_predictions = tf.equal(predictions, tf.cast(labels_placeholder, tf.bool), name="correct_predictions")  
  27.     accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))  
  28.   
  29.   
  30. adam = tf.train.AdamOptimizer(learning_rate=1e-3)  
  31. train_op = adam.minimize(loss)  
  32.   
  33. # generate_data  
  34. inputs = np.random.choice(10, size=[1000010])  
  35. labels = (np.sum(inputs, axis=1) > 45).reshape(-11).astype(np.float32)  
  36. print('inputs.shape:', inputs.shape)  
  37. print('labels.shape:', labels.shape)  
  38.   
  39.   
  40. test_inputs = np.random.choice(10, size=[10010])  
  41. test_labels = (np.sum(test_inputs, axis=1) > 45).reshape(-11).astype(np.float32)  
  42. print('test_inputs.shape:', test_inputs.shape)  
  43. print('test_labels.shape:', test_labels.shape)  
  44.   
  45. batch_size = 32  
  46. epochs = 10  
  47.   
  48. batches = []  
  49. print("%d items in batch of %d gives us %d full batches and %d batches of %d items" % (  
  50.     len(inputs),  
  51.     batch_size,  
  52.     len(inputs) // batch_size,  
  53.     batch_size - len(inputs) // batch_size,  
  54.     len(inputs) - (len(inputs) // batch_size) * 32)  
  55. )  
  56. for i in range(len(inputs) // batch_size):  
  57.     batch = [ inputs[batch_size*i:batch_size*i+batch_size], labels[batch_size*i:batch_size*i+batch_size] ]  
  58.     batches.append(list(batch))  
  59. if (i + 1) * batch_size < len(inputs):  
  60.     batch = [ inputs[batch_size*(i + 1):],labels[batch_size*(i + 1):] ]  
  61.     batches.append(list(batch))  
  62. print("Number of batches: %d" % len(batches))  
  63. print("Size of full batch: %d" % len(batches[0]))  
  64. print("Size if final batch: %d" % len(batches[-1]))  
  65.   
  66. global_count = 0  
  67.   
  68. #with tf.Session() as sess:  
  69. sv = tf.train.Supervisor()  
  70. with sv.managed_session() as sess:  
  71.     #sess.run(tf.initialize_all_variables())  
  72.     for i in range(epochs):  
  73.         for batch in batches:  
  74.             # print(batch[0].shape, batch[1].shape)  
  75.             train_loss , _= sess.run([loss, train_op], feed_dict={  
  76.                 inputs_placeholder: batch[0],  
  77.                 labels_placeholder: batch[1]  
  78.             })  
  79.             # print('train_loss: %d' % train_loss)  
  80.   
  81.             if global_count % 100 == 0:  
  82.                 acc = sess.run(accuracy, feed_dict={  
  83.                     inputs_placeholder: test_inputs,  
  84.                     labels_placeholder: test_labels  
  85.                 })  
  86.                 print('accuracy: %f' % acc)  
  87.             global_count += 1  
  88.   
  89.     acc = sess.run(accuracy, feed_dict={  
  90.         inputs_placeholder: test_inputs,  
  91.         labels_placeholder: test_labels  
  92.     })  
  93.     print("final accuracy: %f" % acc)  
  94.     #在session当中就要将模型进行保存  
  95.     #saver = tf.train.Saver()  
  96.     #last_chkp = saver.save(sess, 'results/graph.chkp')  
  97.     sv.saver.save(sess, 'results/graph.chkp')  
  98.   
  99. for op in tf.get_default_graph().get_operations():  
  100.     print(op.name)  

注意:使用了sv = tf.train.Supervisor(),就不需要再初始化了,将sess.run(tf.initialize_all_variables())注释掉,否则会报错.


在tensorflow中,graph是训练的核心,当一个模型训练完成后,需要将模型保存下来,一个通常的操作是:

variables = tf.all_variables()
                saver = tf.train.Saver(variables)
                saver.save(sess, "data/data.ckpt")
tf.train.write_graph(sess.graph_def, 'graph', 'model.ph', False)
 
 
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

这样就可以将model保存在model.ph文件中,然而使用的时候不仅要加载模型文件model.ph,还要加载保存的data.ckpt数据文件才能使用。这样保持了数据与模型的分离,确实是个不错的方法。 
当我们把一个训练模型完整的训练好上线时候,我们期待的场景是:将一张图片喂进去,然后得出结果。 这时候再这样加载或许有些不必要,特别是在一些变量”不明”的时候特别麻烦.这时候一个比较好的方法就是将变量(偏执,权重等)固化到模型数据中。

创建图

在文件开头增加如下代码 
这里写图片描述

声明tensor

在需要的操作添加 
这里写图片描述

固化保存

这里写图片描述

固化操作中最重要的函数是:

tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None)
 
 
  • 1
  • 1

代码运行后控制台打印: 
这里写图片描述
这样在我们使用的时候就不要再进行data.ckpt的数据恢复。直接通过:

sess.graph.get_tensor_by_name()
 
 
  • 1
  • 1

就可以获取一个tensor,是不是很方便。

小报错:


  • 5
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
训练模型成pb文件: 首先,你需要先定义并训练好一个Tensorflow模型。在模型训练完成后,你可以使用Tensorflow的`freeze_graph.py`脚本将模型转换为pb文件。 在终端中输入以下命令: ``` python tensorflow/python/tools/freeze_graph.py \ --input_graph=<input_graph.pb> \ --input_checkpoint=<input_checkpoint> \ --output_graph=<output_graph.pb> \ --output_node_names=<output_node_names> ``` 其中: - `<input_graph.pb>`:模型的GraphDef文件。 - `<input_checkpoint>`:模型的checkpoint文件。 - `<output_graph.pb>`:转换后的pb文件的输出路径。 - `<output_node_names>`:输出节点的名称,可以在定义模型时指定。 例如: ``` python tensorflow/python/tools/freeze_graph.py \ --input_graph=./models/graph.pb \ --input_checkpoint=./models/model.ckpt \ --output_graph=./models/frozen_graph.pb \ --output_node_names=output_node ``` 这将把`graph.pb`和`model.ckpt`转换为`frozen_graph.pb`文件。其中`output_node`是模型定义时指定的输出节点名称。 加载已经训练好的模型文件: 要加载已经训练好的模型文件,你需要使用Tensorflow的`tf.Session()`来创建一个会话,并使用`tf.train.import_meta_graph()`方法将模型的MetaGraph文件导入到当前的计算中。然后,你可以使用`tf.get_default_graph()`方法获取默认的计算,并使用`get_tensor_by_name()`方法获取模型中的张量。 以下是一个加载已经训练好的模型文件的示例代码: ``` import tensorflow as tf # 创建一个会话 sess = tf.Session() # 加载MetaGraph文件 saver = tf.train.import_meta_graph('./models/model.ckpt.meta') # 恢复变量 saver.restore(sess, './models/model.ckpt') # 获取默认计算 graph = tf.get_default_graph() # 获取模型中的张量 input_tensor = graph.get_tensor_by_name('input:0') output_tensor = graph.get_tensor_by_name('output:0') ``` 在这个例子中,我们使用`saver.restore()`方法恢复了模型的变量,然后获取了模型中的`input`和`output`张量。这里`input`和`output`是在定义模型时所命名的张量名称。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值