with tf.Session() as sess:
sess.run(init)
for epoch in range(3):
for batch in range(n_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
print("Iter " + str(epoch) + ", Testing Accuracy " + str(acc))
saver =tf.train.Saver()
saver.save(sess, "C:\\Users\\Albert\\Desktop\\tensorflow\\albertModel")#路径可以自己定
如上代码所示,在训练完毕后,实例化Saver将Session会话进行保存到指定目录下的文件。
该路径会被添加至checkpoint文件中,当需要恢复(加载)模型时,只需如下进行:
with tf.Session() as sess:
saver.restore(sess, "C:\\Users\\Albert\\Desktop\\tensorflow\\albertModel")
W = sess.run(W1)
for i in range(100):
print ("weights", W[i])
print ("b", sess.run(b1))
新建一个会话,并将之前保存的会话赋予它即可