tf.get_default_graph()
功能:获取当前默认计算图。
import tensorflow as tf

# Two constant ops and their sum. None of them is created inside a
# `with graph.as_default():` block, so they are all added to the default graph.
first = tf.constant(1)
second = tf.constant(2)
total = first + second

# Print the graph that owns `total`, then the current default graph;
# the sample output shows the same Graph object printed twice.
owning_graph = total.graph
print(owning_graph)
print(tf.get_default_graph())
--------------------------------------------------------------
<tensorflow.python.framework.ops.Graph object at 0x00000000025A8BE0>
<tensorflow.python.framework.ops.Graph object at 0x00000000025A8BE0>
tf.train.Saver
功能: 保存模型、加载模型
import os

import tensorflow as tf

v1 = tf.Variable(1, name='v1')
v2 = tf.Variable(2, name='v2')

# Pass the variables as a dict mapping checkpoint names to variables.
# max_to_keep: maximum number of recent checkpoints to keep; older ones are
# deleted as new ones are written. 5 is also the default, and it is spelled
# out here so the setting described by this comment is visible in the call.
saver = tf.train.Saver({'v1': v1, 'v2': v2}, max_to_keep=5)

# Saver.save raises an error when the parent directory of save_path does not
# exist, so make sure the checkpoint directory is there before saving.
os.makedirs('main/data', exist_ok=True)

# Using the Session as a context manager closes it (and releases its
# resources) when the block exits, even if an exception is raised.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # When global_step is provided, the step number is appended to save_path,
    # producing e.g. mnist.ckpt-3.index / .meta / .data-00000-of-00001.
    saver.save(sess, 'main/data/mnist.ckpt', global_step=3)

    # get_checkpoint_state returns None when no valid 'checkpoint' file is
    # found in the directory, so guard before dereferencing it.
    ckpt = tf.train.get_checkpoint_state('main/data/')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(ckpt.model_checkpoint_path)
---------------
checkpoint
mnist.ckpt-3.data-00000-of-00001
mnist.ckpt-3.index
mnist.ckpt-3.meta
main/data/mnist.ckpt-3
tf.summary.scalar
功能: 输出一个包含单个标量值的 summary,用于在 TensorBoard 中显示标量信息(如 loss、accuracy、学习率)。
import tensorflow as tf

# scalar(name, tensor, collections=None)
# Outputs a summary containing a single scalar value. It is typically used to
# plot curves such as loss or accuracy in TensorBoard.
# Placeholder value so the snippet runs on its own; in real training code
# this would be the model's actual learning-rate tensor.
learning_rate = tf.constant(0.01, name='learning_rate')
# With collections=None the op is added to tf.GraphKeys.SUMMARIES, which is
# the collection tf.summary.merge_all() gathers by default.
tf.summary.scalar('learning_rate', learning_rate)