可以复制粘贴直接使用的代码
# Enable on-demand ("memory growth") GPU memory allocation so TensorFlow
# does not grab all GPU memory at process start. This must run before any
# GPU has been initialized, otherwise set_memory_growth raises RuntimeError.
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
当然,如果你还想要限制程序运行在某一个特定的显卡上,复制以下代码加入即可
# Restrict the process to specific GPUs. These environment variables must
# be set before TensorFlow (or any CUDA library) initializes the driver,
# or they are ignored.
import os

# Make CUDA's device numbering follow PCI bus order, so indices match what
# `watch gpustat` / `nvidia-smi` reports.
# BUGFIX: the variable name was misspelled "CDUA_DEVICE_ORDER", which CUDA
# silently ignores.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Comma-separated list of the GPU indices this process is allowed to use.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"