TensorFlow 学习笔记 - Session
Session 拥有并管理 TensorFlow 程序运行时的所有资源。通常,会结合 with 语句使用,在计算完成之后,自动释放系统资源,除非手动关闭。
- 使用tf.Session.run() 获取计算结果:tf.Session.run(result)
""" 手动关闭 session """
sess = tf.Session()
variable_a = tf.constant([1.0, 2.0], name='a')
variable_b = tf.constant([2.0, 3.0], name='b')
result = variable_a + variable_b
# 输出:array([ 3., 5.], dtype=float32)
sess.run(result)
# 关闭 session
sess.close()
""" 自动关闭 Session """
with tf.Session() as sess:
variable_a = tf.constant([1.0, 2.0], name='a')
variable_b = tf.constant([2.0, 3.0], name='b')
result = variable_a + variable_b
# 输出:array([ 3., 5.], dtype=float32)
sess.run(result)
- 使用默认的 Session
- tf.Session().as_default()
- tf.get_default_session()
c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
print(c.eval())
# ...
with sess.as_default():
print(c.eval())
# 需要手动关闭 session,注意这里和在 with 语句里面直接获得 session 不同
sess.close()
使用 tf.Session().as_default() 之后,可以通过变量的 eval() 方法,计算变量的值。但是如果不是在 with 语句块中,会出现如下错误:
raise ValueError("Cannot evaluate tensor using `eval()`: No default "
ValueError: Cannot evaluate tensor using `eval()`: No default session is registered. Use `with sess.as_default()` or pass an explicit session to `eval(session=sess)`
默认的 session 是当前线程的一个属性,这里的当前线程一般就是主线程。当在主线程创建一个新的线程时,并且希望在新的线程里面使用默认的 session,此时必须在新线程中加入 sess.as_default() 语句。
使用 with sess.as_default() 并不意味着计算图 graph 就是 default graph,即 sess.graph 不一定是 tf.get_default_graph, 因为在当前的 session 中可能包含多个 graph。可以通过 sess.graph.as_default() 设置当前的 graph 是 Default graph.
交互式 session
当在交互式环境下,如在 terminal 中运行 Python 代码片段或者在 Jupyter 编辑器中,通过 tf.InteractiveSession 更加方便,它会自动将生成的会话注册为默认会话。
sess = tf.InteractiveSession()
variable_a = tf.constant([1.0, 2.0], name='a')
variable_b = tf.constant([2.0, 3.0], name='b')
print(result.eval())
sess.close()
session 配置文件
使用 tf.ConfigProto() 配置生成 session 的一些参数列表,如最大线程数、GPU 分配策略、运行超时时间等。最常用的参数有两个:allow_soft_placement 和 log_device_placement。
allow_soft_placement
allow_soft_placement 是一个布尔型的参数,默认为 False,通常设置为 True,以此获得更好的代码可移植性。当为 True 时,满足以下任意一个条件时,GPU 上的运算就可以放到 CPU 上运行:
1. 运算无法在 GPU 上运行;
2. 没有 GPU 资源 或者指定的 GPU 设备无法获得;
3. 运算输入包含对 CPU 计算结果的引用;
log_device_placement
log_device_placement 也是一个布尔型的值。当为 True 时,日志中会记录每个节点被安排在了哪个设备上以方便调试。当在生产环境下时,一般设置为 False,以减少日志量。