TensorFlow2.0 中 GradientTape()函数详解
一、函数
tf.GradientTape(
persistent=False, watch_accessed_variables=True
)
二、作用
tensorflow
提供tf.GradientTape
api来实现自动求导功能。只要在tf.GradientTape()
上下文中执行的操作,都会被记录与“tape”中,然后tensorflow
使用反向自动微分来计算相关操作的梯度。
可训练变量(由tf.Variable
或创建tf.compat.v1.get_variable
,trainable=True
在两种情况下均为默认值)将被自动监视。通过watch
在此上下文管理器上调用方法,可以手动监视张量。
三、参数
persistent
:布尔值,用于控制是否创建持久渐变磁带。默认情况下为False,这意味着最多可以在此对象上对gradient()方法进行一次调用。watch_accessed_variables
:布尔值,控制watch
在磁带处于活动状态时磁带是否将自动访问任何(可训练的)变量。默认值为True,可以从磁带中读取可训练的磁带得出的任何结果中请求梯度Variable
。如果为False,则用户必须明确要求他们要从中请求渐变的