TensorFlow会话
会话:一个运行TensorFlow operation的类,会话包含以下两种:
- tf.Session:用于完整的程序当中
- tf.InteractiveSession:用于交互式上下文中的TensorFlow,例如shell
1、张量具体值的查看
- 法1:在会话当中运行后查看
- 法2:快速查看某个张量具体的值,在会话当中使用eval()函数来获取
注:eval()函数仅限于在会话当中使用
2、交互式会话使用
在交互式会话当中,涉及到tf.InteractiveSession()的使用,他会开启一个临时的会话
3、会话的相关操作
(1)会话的创建
__init__(target='', graph=None, config=None)
- 会话是拥有资源的,会话可能拥有的资源,如:tf.Variable、tf.QueueBase和 tf.ReaderBase。当这些资源不再需要的时候,释放这些资源非常重要。因此调用tf.Session.close会话中的方法,或将会话用作上下文管理器
需要注意:
- 会话掌握有资源,用完要回收,所以使用上下文管理器(with)
- 初始化会话对象的参数
- graph=None:指定到底要运行哪个图
- target:如果将此参数留空,会话将仅使用本地计算机中的设备,可以指定grpc://网址,来指定 TensorFlow 服务器的地址,这将会使得会话可以访问该服务器控制的计算机上的所有设备
- config:此参数允许指定一个 tf.ConfigProto 以便控制会话的行为,例如ConfigProto协议用于打印设备使用信息
使用config打印设备使用情况案例:
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def session_demo():
a = tf.constant(100, name="a_t")
b = tf.constant(200, name="b_t")
c = tf.add(a, b, name="c_t")
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True,
log_device_placement=True)) as sess:
# c_t = sess.run(c)
print("default_g attr: ", tf.compat.v1.get_default_graph())
print("c_t: ", c.eval())
print("c graph attr: ", c.graph)
tf.compat.v1.summary.FileWriter("./temp/sunmmary", graph=sess.graph)
if __name__ == '__main__':
session_demo()
(2)会话的run()
run(fetches, feed_dict=None, options=None, run_metadata=None)
- 通过使用 sess.run()方法来运行 operation
- fetches:单一的operation,或者列表、元组(其他不属于tensorflow的类型不行)
- feed_dict:参数允许调用者覆盖途中张量的值,运行时赋值
- feed_dict传递的是字典类型
- 与 tf.placeholder搭配使用,则会检查值的形状是否与占位符兼容
- feed_dict不能单独使用,必须和tf.placeholder搭配使用
featches使用列表案例:
def session_list():
a = tf.constant(100)
b = tf.constant(200)
c = tf.add(a, b)
with tf.compat.v1.Session() as sess:
list = sess.run([a, b, c])
print("list: ", list)
if __name__ == '__main__':
session_list()
feed_dict使用案例:
def feed_demo():
#placeholder就是在不知道具体值的情况下,先定义出来,用于占位
a = tf.compat.v1.placeholder(tf.float32)
b = tf.compat.v1.placeholder(tf.float32)
sum_ab = tf.add(a, b)
with tf.compat.v1.Session() as sess:
print("占位符的结果是:", sess.run(sum_ab, feed_dict={a: 3.33, b: 4.44}))
return None
if __name__ == '__main__':
feed_demo()