1. 检测gpu
pip install tensorflow-gpu==2.5
import tensorflow as tf
# 打印TensorFlow是否使用GPU加速
print(tf.test.gpu_device_name())
2.版本问题
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
3.Keras问题
一般不升级Keras会有问题产生,需要将Keras升级到Tensorflow同样的版本。另外有些函数需要进行改变。
如:修改前(此图是安装完2.x高版本的tensorflow-gpu后,改回原来的代码,SGD报错原因——keras在tensorflow里面,低版本的keras与高版本的tensorflow不兼容,产生报错)
安装tensorflow-gpu,修改后:
4.程序中调用GPU
注意:原本代码不含蓝色一块,由于调用gpu的原因添加,然后可成功调用服务器的gpu(附代码)
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)