一、原因分析
在tensorflow 2.0
里面,要想一个高阶迭代多次调用tf.GradientTape()
时报错,因为tape
是一次性的,算完就会释放,所以要想重复调用必须设置persistent=’True‘
,但是注意如果忘记了释放就会导致GPU被占用
w = tf.constant(1.)
x = tf.constant(2.)
with tf.GradientTape(persistent='True') as tape:
#tape是一次性的算完就会释放,所以要想重复调用,设置persistent=’True‘,但是记得释放因为很占内存
tape.watch([w])#跟踪参数的梯度,必须要,不然就会出现None的情况
y = x*w
二、报错:
tensorflow.python.framework.errors_impl.InternalError: Blas GEMM launch failed :
三、解决办法
关闭GradientTape
的内存占用