参考内容都出自于官方API教程tf.Session
一、Session类基本使用方法
这里使用的是1.15版本,TF官方为了能够在2.0+版本中保持兼容,因此调用时使用了tf.compat.v1.Session。
定义:一个Session是对环境的封装,环境中包含执行/executed过的Operation和评估/evaluated过的Tensor。
一个Session会含有很多资源,例如Variable、QueueBase、RenderBase等。当Session运行结束后需要通过**Session().close()**方法释放资源。
# Build a graph.
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
sess = tf.Session() # Launch the graph in a session.
print(sess.run(c)) # Evaluate the tensor `c`.
sess.close() # 释放资源
# context manager形式,with结束自动关闭sess
with tf.Session() as sess:
sess.run(...)
Session在创建时,构造函数__init__可以指定三个参数:
- target:一般用于分布式TF中,用于连接执行引擎;
- graph:此Session要launch的图,不指定则为默认;
- config:是一个protocol buffer:ConfigProto
# Launch the graph in a session that allows soft device placement and logs the placement decisions. 官方实例
sess = tf.Session(config=tf.ConfigProto(
allow_soft_placement=True, # 对设备采用柔性约束
log_device_placement=True) # 记录信息
)
# 设定GPU增量使用模式,不要全占满GPU:
tfconfig = tf.ConfigProto(allow_soft_placement=True) # set device auto
tfconfig.gpu_options.allow_growth = True # mem increase
with tf.Session(config=tfconfig) as sess:
二、Properties
Session中有很多properties,这些可以直接通过Session().调用查看内部的值。此处介绍Session中关于graph的两个properties:
- graph:返回此Session中已经launch的graph。
- graph_def:将构建的TF图以串行化方式显示出来。
a = tf.constant(1.0)
sess = tf.Session()
assert sess.graph == tf.get_default_graph()
print(sess.graph) # <tensorflow.python.framework.ops.Graph object at 0x7f4a672ac9d0>
print(sess.graph_def) # 返回图的串行化表示:
# node {
# name: "Const"
# op: "Const"
# ...
# versions {
# producer: 22
# }
二、Methods
并没有全列出来,见一个记录一个:
1.as_default()
返回一个context manager将当前Session设置为默认。一般结合with语句进行使用。as_default()语句并不会结束Session,必须调用close()方法手动结束。
c = tf.constant(1.0)
tfconfig = tf.ConfigProto(allow_soft_placement=True) # set device auto
tfconfig.gpu_options.allow_growth = True # mem increase
sess = tf.Session(config=tfconfig)
with sess.as_default():
assert tf.get_default_session() is sess
print tf.get_default_session() # <tensorflow.python.client.session.Session object at 0x7f314d0b3a50>
print(c.eval()) # 1.0 张量使用eval方法进行求值
print(sess.run(c)) # 1.0 也可以使用run进行求值
print(c) # Tensor("Const:0", shape=(), dtype=float32)
在会话中进行运算并取值一共有三种:
- tf.Operation.run():针对操作,Run operations
- tf.Tensor.eval():针对张量,Evaluate tensors
- sess.run():在全局空间取值,下部分会讲
2.run()
用于运行Operations和评估Tensors。
run(
fetches, # 要运行的Op和要评估的Tensor,列表传入
feed_dict=None, # 投喂的数据
options=None,
run_metadata=None
)
三、应用实例:多Graph多Session
import tensorflow as tf
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # set default GPU:0
tfconfig = tf.ConfigProto(allow_soft_placement=True) # set device auto
tfconfig.gpu_options.allow_growth = True # mem increase
g1 = tf.Graph() # Graph 1: actor net
g_ = tf.Graph() # L_Net
g2 = tf.Graph() # Graph 2: disturber net
with g1.as_default():
a = tf.constant(2.0)
b = tf.constant(3.0)
c = tf.constant(4.0)
d = tf.multiply(a, b) + c
with g2.as_default():
e = tf.constant(4.0)
f = tf.constant(6.0)
g = tf.constant(6.0)
h = tf.multiply(e, f) + g
with g_.as_default():
i = tf.placeholder(dtype=tf.float32, shape=[], name='G_Input')
j = i + 100.0
sess1 = tf.Session(graph=g1, config=tfconfig)
sess_ = tf.Session(graph=g_, config=tfconfig)
sess2 = tf.Session(graph=g2, config=tfconfig)
print('result of Net1: ', sess1.run(d)) # 10.0
print('result of Net2: ', sess2.run(h)) # 30.0
print(sess_.run(fetches=j, feed_dict={i: sess1.run(d)})) # Take the output of Net1 as the input of L_Net
print(sess_.run(fetches=j, feed_dict={i: sess2.run(h)})) # Take the output of Net2 as the input of L_Net