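My TensorFlow version is 0.11.
I want to save the graph after training, or save something else that TensorFlow can load later.
I/ Using export and import of the MetaGraph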
X = tf.placeholder("float", [None, 28, 28, 1], name="X")
Y = tf.placeholder("float", [None, 10], name="Y")
saver = tf.train.Saver()  # keep a reference so saver.save() below can use it
with tf.Session() as sess:
    # ... run something ...
    final_tensor = tf.nn.softmax(py_x, name="final_result")
    tf.add_to_collection("final_tensor", final_tensor)
    predict_op = tf.argmax(py_x, 1)
    tf.add_to_collection("predict_op", predict_op)
    saver.save(sess, "my_project")
Then I run load.py:
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph("my_project.meta")
    new_saver.restore(sess, "my_project")
    predict_op = tf.get_collection("predict_op")[0]
    for i in range(2):
        test_indices = np.arange(len(teX))  # get a test batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]
        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={"X:0": teX[test_indices],
                                                         "p_keep_conv:0": 1.0,
                                                         "p_keep_hidden:0": 1.0})))
But it returns an error:
Traceback (most recent call last):
  File "load_05_convolution.py", line 62, in <module>
    "p_keep_hidden:0": 1.0})))
  File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 717, in run
    run_metadata_ptr)
  File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 894, in _run
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (256, 784) for Tensor u"X:0", which has shape "(?, 28, 28, 1)"
I really don't know why.
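One likely reading of this error, offered as a minimal sketch rather than a confirmed diagnosis: the restored placeholder X still expects image-shaped input (?, 28, 28, 1), while the batch being fed is flat (256, 784). Assuming teX was loaded in load.py as flat 784-pixel rows (the zero-filled array below is only a hypothetical stand-in for it), reshaping before feeding should make the shapes agree:

```python
import numpy as np

# Hypothetical stand-in for the flat test set: 256 images of 784 pixels each.
teX = np.zeros((256, 784), dtype=np.float32)

# Reshape to (batch, height, width, channels) to match the "X:0" placeholder.
# -1 lets NumPy infer the batch dimension from the total number of elements.
teX_images = teX.reshape(-1, 28, 28, 1)

print(teX_images.shape)
```

If save.py performed the same reshape before training, load.py would need to repeat it, since the reshape is applied to the data and is not stored in the graph.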
If I add final_tensor = tf.get_collection("final_result")[0]
it returns another error:
Traceback (most recent call last):
  File "load_05_convolution.py", line 46, in <module>
    final_tensor = tf.get_collection("final_result")[0]
IndexError: list index out of range
Is it because tf.add_to_collection only holds a single placeholder?
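The IndexError may come from mixing up two different names. In save.py the tensor was added under the collection key "final_tensor", while "final_result" is the name of the softmax op; tf.get_collection looks up by collection key, so an unused key returns an empty list. The sketch below illustrates the difference. It assumes TensorFlow 1.14 or later, where the graph-mode API is exposed as tf.compat.v1 (on 0.11 the same calls live directly under tf), and it uses a placeholder named logits as a hypothetical stand-in for py_x:

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF >= 1.14; on 0.11 use tf.* directly

graph = tf1.Graph()
with graph.as_default():
    # Hypothetical stand-in for py_x, the network output in save.py.
    logits = tf1.placeholder(tf.float32, [None, 10], name="logits")
    final_tensor = tf.nn.softmax(logits, name="final_result")
    tf1.add_to_collection("final_tensor", final_tensor)

    # "final_result" is an op name, not a collection key: the list is empty.
    by_op_name = tf1.get_collection("final_result")

    # Looking up the key that was actually used returns the tensor.
    by_key = tf1.get_collection("final_tensor")

    # Alternatively, fetch the tensor by its graph name (op name plus ":0").
    by_name = graph.get_tensor_by_name("final_result:0")

print(by_op_name, by_key[0].name, by_name.name)
```

Either lookup style should work after import_meta_graph, as long as the string matches what save.py used.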
II/ Using tf.train.write_graph
I added this line to the end of save.py:
tf.train.write_graph(graph, "folder", "train.pb")
It successfully created the file "train.pb". I then load it with:
with tf.gfile.FastGFile("folder/train.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name="")
with tf.Session() as sess:
    predict_op = sess.graph.get_tensor_by_name("predict_op:0")
    for i in range(2):
        test_indices = np.arange(len(teX))  # get a test batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]
        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={"X:0": teX[test_indices],
                                                         "p_keep_conv:0": 1.0,
                                                         "p_keep_hidden:0": 1.0})))
Then it returns an error:
Traceback (most recent call last):
  File "load_05_convolution.py", line 22, in <module>
    graph_def.ParseFromString(f.read())
  File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/message.py", line 185, in ParseFromString
    self.MergeFromString(serialized)
  File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1085, in MergeFromString
    raise message_mod.DecodeError("Unexpected end-group tag.")
google.protobuf.message.DecodeError: Unexpected end-group tag.
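A plausible cause of this DecodeError: tf.train.write_graph writes a human-readable text protobuf by default (as_text=True), and ParseFromString expects the binary encoding. The sketch below shows both pairings that should work: binary write with ParseFromString, or text write with text_format.Merge. It assumes TensorFlow 1.14 or later via tf.compat.v1 (on 0.11 the same calls live directly under tf); the tiny graph, the temporary directory, and the file name train.pbtxt are hypothetical placeholders:

```python
import os
import tempfile

import tensorflow as tf
from google.protobuf import text_format

tf1 = tf.compat.v1  # assumption: TF >= 1.14; on 0.11 use tf.* directly

# Hypothetical one-node graph standing in for the trained model graph.
graph = tf1.Graph()
with graph.as_default():
    tf1.placeholder(tf.float32, [None, 10], name="X")

out_dir = tempfile.mkdtemp()

# Default (as_text=True): text format, must be read with text_format.
tf1.train.write_graph(graph.as_graph_def(), out_dir, "train.pbtxt")
# as_text=False: binary format, readable with ParseFromString.
tf1.train.write_graph(graph.as_graph_def(), out_dir, "train.pb", as_text=False)

text_def = tf1.GraphDef()
with open(os.path.join(out_dir, "train.pbtxt"), "r") as f:
    text_format.Merge(f.read(), text_def)

binary_def = tf1.GraphDef()
with open(os.path.join(out_dir, "train.pb"), "rb") as f:
    binary_def.ParseFromString(f.read())

print([n.name for n in text_def.node], [n.name for n in binary_def.node])
```

Note that write_graph stores only the graph structure, not the trained variable values, so on its own it does not replace the checkpoint written by saver.save.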
Would you mind sharing a standard way, some code, or a tutorial for saving and loading a model? I'm really confused.