Tensorflow笔记2:Session

参考内容都出自于官方API教程tf.Session

一、Session类基本使用方法

  这里使用的是1.15版本,TF官方为了能够在2.0+版本中保持兼容,因此调用时使用了tf.compat.v1.Session
  定义:一个Session是对环境的封装,环境中包含执行/executed过的Operation和评估/evaluated过的Tensor
  一个Session会含有很多资源,例如Variable、QueueBase、RenderBase等。当Session运行结束后需要通过**Session().close()**方法释放资源。

# Build a graph.
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
sess = tf.Session()	# Launch the graph in a session.
print(sess.run(c))	# Evaluate the tensor `c`.
sess.close()	# 释放资源

# context manager形式,with结束自动关闭sess
with tf.Session() as sess:
	sess.run(...)

  Session在创建时,构造函数__init__可以指定三个参数:

  • target:一般用于分布式TF中,用于连接执行引擎;
  • graph:此Session要launch的图,不指定则为默认;
  • config:是一个protocol buffer:ConfigProto
# Launch the graph in a session that allows soft device placement and logs the placement decisions. 官方实例
sess = tf.Session(config=tf.ConfigProto(
    allow_soft_placement=True,	# 对设备采用柔性约束
    log_device_placement=True)	# 记录信息
    )
    
# 设定GPU增量使用模式,不要全占满GPU:
tfconfig = tf.ConfigProto(allow_soft_placement=True)  # set device auto
tfconfig.gpu_options.allow_growth = True  # mem increase
with tf.Session(config=tfconfig) as sess:

二、Properties

  Session中有很多properties,这些可以直接通过Session().调用查看内部的值。此处介绍Session中关于graph的两个properties:

  • graph:返回此Session中已经launch的graph。
  • graph_def:将构建的TF图以串行化方式显示出来。
a = tf.constant(1.0)
sess = tf.Session()
assert sess.graph == tf.get_default_graph()
print(sess.graph)   # <tensorflow.python.framework.ops.Graph object at 0x7f4a672ac9d0>
print(sess.graph_def)	# 返回图的串行化表示:
# node {
#   name: "Const"
#   op: "Const"
# 	...
#  versions {
#   producer: 22
# }

二、Methods

  并没有全列出来,见一个记录一个:

1.as_default()

  返回一个context manager将当前Session设置为默认。一般结合with语句进行使用。as_default()语句并不会结束Session,必须调用close()方法手动结束。

c = tf.constant(1.0)
tfconfig = tf.ConfigProto(allow_soft_placement=True)  # set device auto
tfconfig.gpu_options.allow_growth = True  # mem increase
sess = tf.Session(config=tfconfig)
with sess.as_default():
    assert tf.get_default_session() is sess
    print tf.get_default_session()  # <tensorflow.python.client.session.Session object at 0x7f314d0b3a50>
    print(c.eval()) # 1.0 张量使用eval方法进行求值
    print(sess.run(c)) # 1.0 也可以使用run进行求值
    print(c)  # Tensor("Const:0", shape=(), dtype=float32)

  在会话中进行运算并取值一共有三种:

  • tf.Operation.run():针对操作,Run operations
  • tf.Tensor.eval():针对张量,Evaluate tensors
  • sess.run():在全局空间取值,下部分会讲

2.run()

  用于运行Operations和评估Tensors。

run(
    fetches,	# 要运行的Op和要评估的Tensor,列表传入
    feed_dict=None,	# 投喂的数据
    options=None,
    run_metadata=None
)

三、应用实例:多Graph多Session

import tensorflow as tf
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # set default GPU:0
tfconfig = tf.ConfigProto(allow_soft_placement=True)  # set device auto
tfconfig.gpu_options.allow_growth = True  # mem increase

g1 = tf.Graph()  # Graph 1: actor net
g_ = tf.Graph()  # L_Net
g2 = tf.Graph()  # Graph 2: disturber net

with g1.as_default():
    a = tf.constant(2.0)
    b = tf.constant(3.0)
    c = tf.constant(4.0)
    d = tf.multiply(a, b) + c

with g2.as_default():
    e = tf.constant(4.0)
    f = tf.constant(6.0)
    g = tf.constant(6.0)
    h = tf.multiply(e, f) + g

with g_.as_default():
    i = tf.placeholder(dtype=tf.float32, shape=[], name='G_Input')
    j = i + 100.0

sess1 = tf.Session(graph=g1, config=tfconfig)
sess_ = tf.Session(graph=g_, config=tfconfig)
sess2 = tf.Session(graph=g2, config=tfconfig)

print('result of Net1: ', sess1.run(d))  # 10.0
print('result of Net2: ', sess2.run(h))  # 30.0
print(sess_.run(fetches=j, feed_dict={i: sess1.run(d)}))  # Take the output of Net1 as the input of L_Net
print(sess_.run(fetches=j, feed_dict={i: sess2.run(h)}))  # Take the output of Net2 as the input of L_Net
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值