在默认情况下,tensorflow为避免显存碎片化,会使用所有显存投入到计算。在这种情况,容易出现Blas GEMM launch failed错误。因此,需要对tensorflow的显存使用进行限制。在tensorflow官网中,给出了两种方法:
第一种方法是随调随用,即在初始的时候只使用少量显存,然后在计算过程中逐步增加显存。主要调用函数是tf.config.experimental.set_memory_growth,代码如下:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
try:
# Currently, memory growth needs to be the same across GPUs
for