TF2.0默认为动态图,即eager模式。意味着TF能像Pytorch一样不用在session中才能输出中间参数值了,那么动态图和静态图毕竟是有区别的,tf2.0也会有写法上的变化。不过值得吐槽的是,tf2.0启动速度仍然比Pytorch慢的多。
操作被记录在磁带中(tape)
这是一个关键的变化。在TF0.x到TF1.X时代,操作(operation)被加入到Graph中。但现在,操作会被梯度带记录,我们要做的仅仅是让前向传播和计算损失的过程发生在梯度带的上下文管理器中。
with tf.GradientTape() as tape:
logits = mnist_model(images, training=True)
loss_value = tf.losses.sparse_softmax_cross_entropy(labels, logits)
# loss_value 必须在tape内部
grads = tape.gradient(loss_value, mnist_model.variables)
optimizer.apply_gradients(zip(grads, mnist_model.variables),
global_step=tf.train.get_or_create_global_step())
注意到这里的tape.gradient用来计算损失函数和model参数的倒数。我们在之前的版本要吗使用优化器的minimize功能,要吗使用tf.gradients来计算导数。在eager模式,tf.gradients不能使用
RuntimeError: tf.gradients is not supported when eager execution is enabled. Use tf.GradientTape instead.
因此,tf2.0的训练风格就如上所示。总结一下:
- 新建一个tf.GradientTape的上下文管理器
- 在管理器内部构建模型的前向传播和计算损失函数。
- 计算每个训练参数(trainable)的导数
- 优化器用计算得到的导数对模型参数更新。
虽然我们明白了在tf2.0中训练模型的新风格,但不妨深入了解一下磁带机制。
深入理解tape
我们随便写一个操作。
import tensorflow as tf
# tf.enable_eager_execution() 1.12往后的tf1.x需要加入这一条代码
with tf.GradientTape() as tape:
x = tf.Variable(tf.ones((2,2)))
y = tf.Variable([[2,2],[2,2]],dtype=tf.float32)
# y = x+1
z = tf.matmul(x,y)
loss = tf.reduce_mean(z)
grads = tape.gradient(loss,x)
print(grads)
>>> tf.Tensor(
[[1. 1.]
[1. 1.]], shape=(2, 2), dtype=float32)
同时tape.gradient还能一次对多个变量求导
import tensorflow as tf
tf.enable_eager_execution()
with tf.GradientTape() as tape:
x = tf.Variable(tf.ones((2,2)))
y = tf.Variable([[2,2],[2,2]],dtype=tf.float32)
# y = x+1
z = tf.matmul(x,y)
loss = tf.reduce_mean(z)
grads = tape.gradient(loss,[x,y]) # 用[ ]装载需要求导的变量
print(grads)
>>> [<tf.Tensor: id=32, shape=(2, 2), dtype=float32, numpy=
array([[1., 1.],
[1., 1.]], dtype=float32)>, <tf.Tensor: id=33, shape=(2, 2), dtype=float32, numpy=
array([[0.5, 0.5],
[0.5, 0.5]], dtype=float32)>]
GradientTape的persistent可选参数
我们知道在Pytorch中,loss.backward()有一个参数为retain_graph,用来防止一次backward就将图释放了。在TF2.0中,同样有一个类似的功能。
def init(self, persistent=False, watch_accessed_variables=True):
上面是GradientTape的构造函数。那个persistent就是用来让tape记录的操作不会因为一次tape.gradient()调用就释放。
当我们想释放这一次的前向传播过程记录时,只需要将tape删除,实现python的垃圾回收机制即可。
del tape
动态图和静态图的区别
使用 Graph Execution 时,程序状态(如变量)存储在全局集合中,它们的生命周期由 tf.Session 对象管理。相反,在 Eager Execution 期间,状态对象的生命周期由其对应的 Python 对象的生命周期决定。