一、环境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
cudnn64_7.dll
Python 3.6.3
Windows 10
二、官方说明
https://www.tensorflow.org/api_docs/python/tf/device
使用默认图(default graph)是Graph.device()的包装器(wrapper)
详细参照tf.Graph.device
输入:
在上下文中使用的设备名称或者函数
输出:
为新创建的操作指定默认设备的上下文管理器
三:实例
(1)使用CPU
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [vocab_size, size], dtype=tf.float32)
inputs = tf.nn.embedding_lookup(embedding, input_.input_data)
(2)使用GPU
with tf.device("/gpu:0"):
embedding = tf.get_variable("embedding", [vocab_size, size], dtype=tf.float32)
inputs = tf.nn.embedding_lookup(embedding, input_.input_data)