python如何训练模型-python – 如何在tensorflow中训练后使用模型(保存/加载图)

我的张量流版本是0.11.

我希望在训练后保存图形或保存tensorflow可以加载的其他东西.

我/使用导出和导入MetaGraph

X = tf.placeholder("float", [None, 28, 28, 1], name="X")

Y = tf.placeholder("float", [None, 10], name="Y")

tf.train.Saver()

with tf.Session() as sess:

...run something ...

final_tensor = tf.nn.softmax(py_x, name="final_result")

tf.add_to_collection("final_tensor", final_tensor)

predict_op = tf.argmax(py_x, 1)

tf.add_to_collection("predict_op", predict_op)

saver.save(sess, "my_project")

然后我运行load.py:

with tf.Session() as sess:

new_saver = tf.train.import_meta_graph("my_project.meta")

new_saver.restore(sess, "my_project")

predict_op = tf.get_collection("predict_op")[0]

for i in range(2):

test_indices = np.arange(len(teX)) # Get A Test Batch

np.random.shuffle(test_indices)

test_indices = test_indices[0:test_size]

print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==

sess.run(predict_op, feed_dict={"X:0": teX[test_indices],

"p_keep_conv:0": 1.0,

"p_keep_hidden:0": 1.0})))

但它返回错误

Traceback (most recent call last):

File "load_05_convolution.py", line 62, in

"p_keep_hidden:0": 1.0})))

File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 717, in run

run_metadata_ptr)

File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 894, in _run

% (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))

ValueError: Cannot feed value of shape (256, 784) for Tensor u"X:0", which has shape "(?, 28, 28, 1)"

我真的不知道为什么?

如果我添加final_tensor = tf.get_collection(“final_result”)[0]

它返回另一个错误:

Traceback (most recent call last):

File "load_05_convolution.py", line 46, in

final_tensor = tf.get_collection("final_result")[0]

IndexError: list index out of range

是因为tf.add_to_collection只包含一个占位符吗?

II /使用tf.train.write_graph

我将此行添加到save.py的末尾

tf.train.write_graph(graph,’folder’,’train.pb’)

它成功创建了文件’train.pb’

with tf.gfile.FastGFile("folder/train.pb", "rb") as f:

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

_ = tf.import_graph_def(graph_def, name="")

with tf.Session() as sess:

predict_op = sess.graph.get_tensor_by_name("predict_op:0")

for i in range(2):

test_indices = np.arange(len(teX)) # Get A Test Batch

np.random.shuffle(test_indices)

test_indices = test_indices[0:test_size]

print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==

sess.run(predict_op, feed_dict={"X:0": teX[test_indices],

"p_keep_conv:0": 1.0,

"p_keep_hidden:0": 1.0})))

然后它返回错误:

Traceback (most recent call last):

File "load_05_convolution.py", line 22, in

graph_def.ParseFromString(f.read())

File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/message.py", line 185, in ParseFromString

self.MergeFromString(serialized)

File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1085, in MergeFromString

raise message_mod.DecodeError("Unexpected end-group tag.")

google.protobuf.message.DecodeError: Unexpected end-group tag.

你会介意分享标准方式,代码或教程来保存/加载模型吗?我真的很困惑.

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值