1.py_function函数
py_function是tensorflow 2.x系列中的API。
@tf_export("py_function")
@dispatch.add_dispatch_support
def eager_py_func(func, inp, Tout, name=None):
"""Wraps a python function into a TensorFlow op that executes it eagerly.
This function allows expressing computations in a TensorFlow graph as
Python functions. In particular, it wraps a Python function `func`
in a once-differentiable TensorFlow operation that executes it with eager
execution enabled. As a consequence, `tf.py_function` makes it
possible to express control flow using Python constructs (`if`, `while`,
`for`, etc.), instead of TensorFlow control flow constructs (`tf.cond`,
`tf.while_loop`). For example, you might use `tf.py_function` to
implement the log huber function:
根据上面的解释,可以看出py_function的作用为:
将一个python原生函数包装成一个tf的operation操作,这样可以在函数表达式中比较简单地加入各种python控制模块,比如if, while, for等等操作。
什么情况下将会使用到py_function操作呢?当我们有一些操作,tf中没有内置,需要我们自己编写相应运行逻辑的时候,这个时候就用到了py_function。
比如以源码中提到的grad_huber为例:
def log_huber(x, m):
if tf.abs(x) <= m:
return x**2
else:
return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))
然后我们想对x求导:
import tensorflow as tf
def grad_huber():
x = tf.Variable(initial_value=1.5)
m = tf.Variable(initial_value=2.0)
with tf.GradientTape(persistent=True) as g:
g.watch(x)
y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)
print(y)
dy_dx = g.gradient(y, x)
print(dy_dx)
对于该分段函数,此时x=1.5, m=2.0,x<m,因此此时函数为y=x**2。y的值为2.25, y的梯度为2x=3.0。
代码的输出为:
tf.Tensor(2.25, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
2.GradientTape中的watch函数
GradientTape默认只监控tf.Variable中创建的,且trainable=True的变量。如果想手动指定监控的变量,可以使用watch函数。
def grad_v2():
x = tf.constant(3.0)
with tf.GradientTape() as g:
g.watch(x)
with tf.GradientTape() as gg:
gg.watch(x)
y = x**2
dy_dx = gg.gradient(y, x)
d2y_dx2 = g.gradient(dy_dx, x)
print("dy_dx is: ", dy_dx)
print("d2y_dx2 is: ", d2y_dx2)
上面的代码中,y = x**2。我们想求在x=3处,y的一阶导与二阶导。因为x是constant,所以需要先watch先加入以便自动计算梯度。
dy_dx is: tf.Tensor(6.0, shape=(), dtype=float32)
d2y_dx2 is: tf.Tensor(2.0, shape=(), dtype=float32)
3.GradientTape中persistent属性
GradientTape默认的是在调用gradient方法以后,资源就被释放,无法进行下一次梯度计算。看如下例子:
def grad_persistent():
x = tf.constant(3.0)
with tf.GradientTape() as tape:
tape.watch(x)
y = x**2
z = x**3
dy_dx = tape.gradient(y, x)
dz_dx = tape.gradient(z, x)
print("dy_dx is: ", dy_dx)
print("dz_dx is: ", dz_dx)
grad_persistent()
上述代码直接运行的话,会爆出如下错误:
RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes.
此时需要将GradientTape相关代码修改如下:
with tf.GradientTape(persistent=True) as tape:
再运行上面代码,就可以得到正确输出:
dy_dx is: tf.Tensor(6.0, shape=(), dtype=float32)
dz_dx is: tf.Tensor(27.0, shape=(), dtype=float32)