TensorFlow笔记——会话

  • 前言

TensorFlow使用tf.Session()类来表示客户端程序(通常为Python程序,也提供了使用其他语言的类似接口)与C++运行时之间的连接,因为TensorFlow底层是用C++实现,会话提供了沟通它们的接口,使我们创建的数据和操作能运行起来。而且,tf.Session对象使用分布式TensorFlow运行时提供对本地计算机中的设备和远程设备的访问权限。

  • 开启会话

一个运行TensorFlow operation的类,会包含以下两种开启方式

tf.Session:用于完整的程序当中

tf.InteractiveSession:用于交互式上下文中的TensorFlow,例如shell

先说下tf.InteractiveSession,tf.Session下面细说,两种开启方式说完后在说下区别

tf.InteractiveSession()是一种交互式的session方式,它让自己成为了默认的session,也就是说用户在不需要指明用哪个session运行的情况下,就可以运行起来。若在交互式环境中,就很方便,如ipython

比如,我们想要看某个张量具体的值,在图中我们是看不了的,要在会话中才能看到,

import tensorflow as tf 


def session_demo():

    a = tf.constant(5)
    b = tf.constant(6)
    sum_c = tf.add(a,b)

    tf.InteractiveSession()
    print(sum_c.eval())
        
    return None

if __name__ == "__main__":
    session_demo()

可以看到,我们在开启交互之后,直接通过eval()的方式就可以运行操作

————————————————————————————————————————————————————

接着说tf.Session,会话可能拥有的资源,如tf.Variable,tf.QueueBase和tf.ReaderBase。当这些资源不再需要时,释放这些资源非常重要。因此,需要调用tf.Session.close会话中的方法,或将会话作用与上下文管理器。

 

先看下会话中各参数的意义:

创建会话对象__init__(target=’’,graph=None,config= None)

target:如果将此参数留空(默认设置),会话将仅使用本地计算机中的设备。可以指定grppc://网址,以便指定TensorFlow服务器的地址,这使得会话可以访问服务器控制的计算机上的所有设备。

graph:默认情况下,新的tf.Session将绑定到当前的默认图,这个在我的第一篇博客中有提到,这里就不细说了

config:此参数允许您指定一个tf.ConfigProto以控制会话的行为,例如,ConfigProto协议用于打印设备使用信息。

这个举个简单例子就明白了

import tensorflow as tf 

def session_demo():

    a = tf.constant(5)
    b = tf.constant(6)
    sum_c = tf.add(a,b)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,log_device_placement=True)) as sess:
        sum_c_value = sess.run(sum_c)
        print("sum_c_value:\n",sum_c_value)
            
    return None

if __name__ == "__main__":
    session_demo()

输出结果:

Device mapping: no known devices.
Add: (Add): /job:localhost/replica:0/task:0/device:CPU:0
Const: (Const): /job:localhost/replica:0/task:0/device:CPU:0
Const_1: (Const): /job:localhost/replica:0/task:0/device:CPU:0
sum_c_value:
 11

其中:

allow_soft_placement = True : 如果你指定的设备不存在,允许TF自动分配设备
log_device_placement = True : 是否打印设备分配日志

自动打印了设备信息,还可以控制GPU资源等,相关的可自行百度

 

  • 传统的会话定义与使用上下文管理器的比较:

传统的会话定义:

import tensorflow as tf 

def session_demo():

    a = tf.constant(5)
    b = tf.constant(6)
    sum_c = tf.add(a,b)
    
    #开启会话
    #传统的会话定义
    sess = tf.Session()
    sum_c_value = sess.run(sum_c)
    print("sum_c_value:\n",sum_c_value)
    sess.close()
    
    return None

if __name__ == "__main__":
    session_demo()

使用上下文管理器:

import tensorflow as tf 

def session_demo():

    a = tf.constant(5)
    b = tf.constant(6)
    sum_c = tf.add(a,b)
    
    #采用上下文管理器
    with tf.Session() as sess:
        sum_c_value = sess.run(sum_c)
        print("sum_c_value:\n",sum_c_value)
    
    return None

if __name__ == "__main__":
    session_demo()

两种方法输出结果相同,传统的会话定义方式需要自己调用close()方法来释放资源,如果我们过早的释放了资源,很可能会产生错误,我们一般将会话作用与上下文管理器,让其在合适的时候自动将其关闭。

————————————————————————————————————————————————————

两种启动方式说完了,来说下两者的区别:

tf.InteractiveSession()默认自己就是用户要操作的session;而tf.Session()没有这个默认,因此用eval()启动计算时需要指明session。

  • 会话的run()

run(fetches,feed_dict = None,options=None,run_metadata=None)

会话创建好之后,通过sess.run()来运行Operation

fetches:单一的Operation,或者列表、元组(其他不属于TensorFlow的类型不行)

feed_dict:参加允许调用这覆盖图中张量的值,运行时赋值,与tf.placeholder搭配使用,则会检查值的形状是否与占位符兼容

 

使用tf.operation.eval()也可以运行operation,但需要在会话中运行

fetches传入多个值时,这里以列表为例:

import tensorflow as tf 

def session_demo():

    a_t = tf.constant(5)
    b_t = tf.constant(6)
    c_t = tf.add(a_t,b_t)

    with tf.Session() as sess:

        #同时查看a_t,b_t,c_t
        abc = sess.run([a_t,b_t,c_t])
        print("abc_1:\n",abc)
        
        a,b,c= sess.run([a_t,b_t,c_t])
        print("abc_2:\n",a,b,c)
       
    return None

if __name__ == "__main__":
    session_demo()

输出结果:

abc_1:
 [5, 6, 11]
abc_2:
 5 6 11

  • feed操作:

有的时候,我们在定义张量的时候,我们并不确定具体的值时什么,就可以用placeholder去定义,相当于先占了一个位置

placeholder提供占位符,定义时必须要填类型,run时候通过feed_dict指定参数,

import tensorflow as tf 

def ph_demo():

    #定义占位符,此时我们并不知道a_ph,b_ph的值
    a_ph = tf.placeholder(tf.float32)
    b_ph = tf.placeholder(tf.float32)
    c_ph = tf.add(a_ph,b_ph)
    print('a_ph:\n',a_ph)
    print('b_ph:\n',b_ph)
    print('c_ph:\n',c_ph)

if __name__ == "__main__":
    ph_demo()

运行结果:

a_ph:
 Tensor("Placeholder:0", dtype=float32)
b_ph:
 Tensor("Placeholder_1:0", dtype=float32)
c_ph:
 Tensor("Add:0", dtype=float32)

如果运行时不指定feed_dict:

import tensorflow as tf 

def ph_demo():

    #定义占位符
    a_ph = tf.placeholder(tf.float32)
    b_ph = tf.placeholder(tf.float32)
    c_ph = tf.add(a_ph,b_ph)
    print('a_ph:\n',a_ph)
    print('b_ph:\n',b_ph)
    print('c_ph:\n',c_ph)


    with tf.Session() as sess:

        #运行placeholder
        c_ph_value = sess.run(c_ph)
        print("c_ph_value:\n",c_ph_value)


if __name__ == "__main__":
    ph_demo()

运行时会报错:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float
         [[{{node Placeholder_1}} = Placeholder[dtype=DT_FLOAT, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

运行时指定feed_dict,如:

c_ph_value = sess.run(c_ph,feed_dict={a_ph:3.0,b_ph:4.0})

运行结果:

a_ph:
 Tensor("Placeholder:0", dtype=float32)
b_ph:
 Tensor("Placeholder_1:0", dtype=float32)
c_ph:
 Tensor("Add:0", dtype=float32)
c_ph_value:
 7.0

通过feed_dict传一个字典,告诉会话a_ph,b_ph的值;当然,传的数据必须和所定义的类型一致

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值