当我们需要启用GPU对模型进行训练的时候,常常需要使用到一下代码对session进行设置。
with tf.Graph().as_default():
gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory_fraction)
sess=tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,log_device_placement=False))
with sess.as_default():
tf.graph
tensorflow的运算被表示为一个数据流图,其中包含运算结点(操作)和数据结点(数据对象)。在我们开始任务时,tensorflow会提供一个默认的graph,如果我们没有显式的定义一个graph,那么我们接下来的操作就会基于这个graph完成,一般会使用到as_default()
函数和with
关键字在代码块内完成对graph的操作。
tf.session
TensorFlow 使用 tf.Session 类来表示客户端程序(通常为 Python 程序,但也提供了其他语言的类似接口)与 C++ 运行时之间的连接。tf.Session 对象使我们能够访问本地机器中的设备和使用分布式 TensorFlow 运行时的远程设备。它还可缓存关于 tf.Graph 的信息,使我们能够多次高效地运行同一计算。
session可以通过函数进行参数设置,同时可以根据graph中的操作结点对数据结点进行运算。
使用as_default()
函数和with
关键字在代码块内完成对session的操作。
使用close()
关闭session,有关该session的资源都会被释放。
tf.ConfigProto
此函数用于对session的参数设置,如文章开头的代码段所示。
tf.GPUOptions:
在构造tf.Session()时可通过tf.GPUOptions作为可选配置参数的一部分来显示地指定需要分配的显存比例。
per_process_gpu_memory_fraction指定了每个GPU进程中使用显存的上限,但它只能均匀地作用于所有GPU,无法对不同GPU设置不同的上限。
tf.ConfigProto参数如下:
log_device_placement=True : 是否打印设备分配日志
allow_soft_placement=True : 如果你指定的设备不存在,允许TF自动分配设备
tf.ConfigProto(log_device_placement=True,allow_soft_placement=True)