-
前言
TensorFlow使用tf.Session()类来表示客户端程序(通常为Python程序,也提供了使用其他语言的类似接口)与C++运行时之间的连接,因为TensorFlow底层是用C++实现,会话提供了沟通它们的接口,使我们创建的数据和操作能运行起来。而且,tf.Session对象使用分布式TensorFlow运行时提供对本地计算机中的设备和远程设备的访问权限。
-
开启会话
一个运行TensorFlow operation的类,会包含以下两种开启方式
tf.Session:用于完整的程序当中
tf.InteractiveSession:用于交互式上下文中的TensorFlow,例如shell
先说下tf.InteractiveSession,tf.Session下面细说,两种开启方式说完后在说下区别
tf.InteractiveSession()是一种交互式的session方式,它让自己成为了默认的session,也就是说用户在不需要指明用哪个session运行的情况下,就可以运行起来。若在交互式环境中,就很方便,如ipython
比如,我们想要看某个张量具体的值,在图中我们是看不了的,要在会话中才能看到,
import tensorflow as tf
def session_demo():
a = tf.constant(5)
b = tf.constant(6)
sum_c = tf.add(a,b)
tf.InteractiveSession()
print(sum_c.eval())
return None
if __name__ == "__main__":
session_demo()
可以看到,我们在开启交互之后,直接通过eval()的方式就可以运行操作
————————————————————————————————————————————————————
接着说tf.Session,会话可能拥有的资源,如tf.Variable,tf.QueueBase和tf.ReaderBase。当这些资源不再需要时,释放这些资源非常重要。因此,需要调用tf.Session.close会话中的方法,或将会话作用与上下文管理器。
先看下会话中各参数的意义:
创建会话对象__init__(target=’’,graph=None,config= None)
target:如果将此参数留空(默认设置),会话将仅使用本地计算机中的设备。可以指定grppc://网址,以便指定TensorFlow服务器的地址,这使得会话可以访问服务器控制的计算机上的所有设备。
graph:默认情况下,新的tf.Session将绑定到当前的默认图,这个在我的第一篇博客中有提到,这里就不细说了
config:此参数允许您指定一个tf.ConfigProto以控制会话的行为,例如,ConfigProto协议用于打印设备使用信息。
这个举个简单例子就明白了
import tensorflow as tf
def session_demo():
a = tf.constant(5)
b = tf.constant(6)
sum_c = tf.add(a,b)
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,log_device_placement=True)) as sess:
sum_c_value = sess.run(sum_c)
print("sum_c_value:\n",sum_c_value)
return None
if __name__ == "__main__":
session_demo()
输出结果:
Device mapping: no known devices.
Add: (Add): /job:localhost/replica:0/task:0/device:CPU:0
Const: (Const): /job:localhost/replica:0/task:0/device:CPU:0
Const_1: (Const): /job:localhost/replica:0/task:0/device:CPU:0
sum_c_value:
11
其中:
allow_soft_placement = True : 如果你指定的设备不存在,允许TF自动分配设备
log_device_placement = True : 是否打印设备分配日志
自动打印了设备信息,还可以控制GPU资源等,相关的可自行百度
- 传统的会话定义与使用上下文管理器的比较:
传统的会话定义:
import tensorflow as tf
def session_demo():
a = tf.constant(5)
b = tf.constant(6)
sum_c = tf.add(a,b)
#开启会话
#传统的会话定义
sess = tf.Session()
sum_c_value = sess.run(sum_c)
print("sum_c_value:\n",sum_c_value)
sess.close()
return None
if __name__ == "__main__":
session_demo()
使用上下文管理器:
import tensorflow as tf
def session_demo():
a = tf.constant(5)
b = tf.constant(6)
sum_c = tf.add(a,b)
#采用上下文管理器
with tf.Session() as sess:
sum_c_value = sess.run(sum_c)
print("sum_c_value:\n",sum_c_value)
return None
if __name__ == "__main__":
session_demo()
两种方法输出结果相同,传统的会话定义方式需要自己调用close()方法来释放资源,如果我们过早的释放了资源,很可能会产生错误,我们一般将会话作用与上下文管理器,让其在合适的时候自动将其关闭。
————————————————————————————————————————————————————
两种启动方式说完了,来说下两者的区别:
tf.InteractiveSession()默认自己就是用户要操作的session;而tf.Session()没有这个默认,因此用eval()启动计算时需要指明session。
-
会话的run()
run(fetches,feed_dict = None,options=None,run_metadata=None)
会话创建好之后,通过sess.run()来运行Operation
fetches:单一的Operation,或者列表、元组(其他不属于TensorFlow的类型不行)
feed_dict:参加允许调用这覆盖图中张量的值,运行时赋值,与tf.placeholder搭配使用,则会检查值的形状是否与占位符兼容
使用tf.operation.eval()也可以运行operation,但需要在会话中运行
fetches传入多个值时,这里以列表为例:
import tensorflow as tf
def session_demo():
a_t = tf.constant(5)
b_t = tf.constant(6)
c_t = tf.add(a_t,b_t)
with tf.Session() as sess:
#同时查看a_t,b_t,c_t
abc = sess.run([a_t,b_t,c_t])
print("abc_1:\n",abc)
a,b,c= sess.run([a_t,b_t,c_t])
print("abc_2:\n",a,b,c)
return None
if __name__ == "__main__":
session_demo()
输出结果:
abc_1:
[5, 6, 11]
abc_2:
5 6 11
-
feed操作:
有的时候,我们在定义张量的时候,我们并不确定具体的值时什么,就可以用placeholder去定义,相当于先占了一个位置
placeholder提供占位符,定义时必须要填类型,run时候通过feed_dict指定参数,
import tensorflow as tf
def ph_demo():
#定义占位符,此时我们并不知道a_ph,b_ph的值
a_ph = tf.placeholder(tf.float32)
b_ph = tf.placeholder(tf.float32)
c_ph = tf.add(a_ph,b_ph)
print('a_ph:\n',a_ph)
print('b_ph:\n',b_ph)
print('c_ph:\n',c_ph)
if __name__ == "__main__":
ph_demo()
运行结果:
a_ph:
Tensor("Placeholder:0", dtype=float32)
b_ph:
Tensor("Placeholder_1:0", dtype=float32)
c_ph:
Tensor("Add:0", dtype=float32)
如果运行时不指定feed_dict:
import tensorflow as tf
def ph_demo():
#定义占位符
a_ph = tf.placeholder(tf.float32)
b_ph = tf.placeholder(tf.float32)
c_ph = tf.add(a_ph,b_ph)
print('a_ph:\n',a_ph)
print('b_ph:\n',b_ph)
print('c_ph:\n',c_ph)
with tf.Session() as sess:
#运行placeholder
c_ph_value = sess.run(c_ph)
print("c_ph_value:\n",c_ph_value)
if __name__ == "__main__":
ph_demo()
运行时会报错:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float
[[{{node Placeholder_1}} = Placeholder[dtype=DT_FLOAT, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
运行时指定feed_dict,如:
c_ph_value = sess.run(c_ph,feed_dict={a_ph:3.0,b_ph:4.0})
运行结果:
a_ph:
Tensor("Placeholder:0", dtype=float32)
b_ph:
Tensor("Placeholder_1:0", dtype=float32)
c_ph:
Tensor("Add:0", dtype=float32)
c_ph_value:
7.0
通过feed_dict传一个字典,告诉会话a_ph,b_ph的值;当然,传的数据必须和所定义的类型一致