在TensorFlow 2中,默认情况下将打开急切执行功能。用户界面直观而灵活(运行一次性操作要容易得多且更快),但这可能会牺牲性能和可部署性。
您可以使用tf.function
程序制作图形。它是一种转换工具,可以从您的Python代码中创建与Python无关的数据流图。这将帮助您创建高性能和可移植的模型,并且需要使用SavedModel
。
本指南将帮助您概念化如何tf.function
在引擎盖下工作,以便您可以有效地使用它。
主要要点和建议是:
- 在紧急模式下进行调试,然后使用进行修饰
@tf.function
。 - 不要依赖于Python的副作用,例如对象突变或列表追加。
tf.function
与TensorFlow ops配合使用效果最佳;NumPy和Python调用将转换为常量。
设置
import tensorflow as tf
定义一个辅助函数来演示您可能遇到的错误类型:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
try:
yield
except error_class as e:
print('Caught expected exception \n {}:'.format(error_class))
traceback.print_exc(limit=2)
except Exception as e:
raise e
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
基本
用法
一个Function
定义(例如,通过应用@tf.function
装饰)就像是一个核心TensorFlow操作:您可以急切地执行它; 您可以计算渐变;等等。
@tf.function # The decorator converts `add` into a `Function`.
def add(a, b):
return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
[2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
您可以Function
在其他Function
s中使用。
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
[3., 3.],
[3., 3.]], dtype=float32)>
Function
s可能比渴望的代码快,尤其是对于具有许多小操作的图形。但是对于具有一些昂贵操作(例如卷积)的图,您可能看不到太多的加速。
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.003693543000053978
Function conv: 0.004675635000012335
Note how there's not much difference in performance for convolutions
追踪
本节介绍了幕后Function
工作原理,包括将来可能会更改的实现细节。但是,一旦您了解了跟踪的原因和时间,tf.function
有效地使用它就容易多了!
什么是“追踪”?
AFunction
在TensorFlow Graph中运行程序。但是,atf.Graph
不能代表您在急切的TensorFlow程序中编写的所有内容。例如,Python支持多态,但是tf.Graph
要求其输入具有指定的数据类型和维。或者,您可以执行其他任务,例如读取命令行参数,引发错误或使用更复杂的Python对象。这些东西都不能在Windows中运行tf.Graph
。
Function
通过将您的代码分为两个阶段来弥合这种差距:
1)在第一阶段,称为“跟踪”,Function
创建一个新的tf.Graph
。Python代码正常运行,但是所有TensorFlow操作(如添加两个Tensor)都被延迟:它们被捕获,tf.Graph
并且不运行。
2)在第二阶段中,tf.Graph
运行包含第一阶段中延迟的所有内容的a 。此阶段比跟踪阶段快得多。
根据其输入,Function
调用它时将不会始终运行第一阶段。请参阅下面的“跟踪规则”以更好地了解它是如何进行确定的。跳过第一阶段,仅执行第二阶段,即可为您带来TensorFlow的高性能。
当Function
确定要跟踪时,跟踪阶段紧随其后的是第二阶段,因此调用Function
两者都可以创建并运行tf.Graph
。稍后,您将看到如何仅使用来运行跟踪阶段get_concrete_function
。
当我们将不同类型的参数传递给时Function
,两个阶段都将运行:
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)
请注意,如果您重复调用Function
具有相同参数类型的a,TensorFlow将跳过跟踪阶段并重用先前跟踪的图,因为生成的图将是相同的。
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
您可以pretty_printed_concrete_signatures()
用来查看所有可用的跟踪:
print(double.pretty_printed_concrete_signatures())
double(a)
Args:
a: float32 Tensor, shape=()
Returns:
float32 Tensor, shape=()
double(a)
Args:
a: string Tensor, shape=()
Returns:
string Tensor, shape=()
double(a)
Args:
a: int32 Tensor, shape=()
Returns:
int32 Tensor, shape=()
到目前为止,您已经看到tf.function
在TensorFlow的图跟踪逻辑上创建了一个缓存的动态调度层。要更具体地讲术语,请执行以下操作:
- A
tf.Graph
是TensorFlow计算的原始,与语言无关的可移植表示形式。 - A
ConcreteFunction
包装一个tf.Graph
。 - A
Function
管理ConcreteFunction
s的缓存,并为您的输入选择正确的一个。 tf.function
包装一个Python函数,返回一个Function
对象。- 跟踪创建一个
tf.Graph
并将其包装在中ConcreteFunction
,也称为跟踪。
追踪规则
AFunction
确定是否重用ConcreteFunction
通过从输入的args和kwargs计算缓存键来跟踪的跟踪。甲缓存键是标识密钥ConcreteFunction
基于输入指定参数和所述的kwargsFunction
呼叫,根据下面的规则(其可以改变):
- 为a生成的密钥
tf.Tensor
是其形状和dtype。 - 为a生成的密钥
tf.Variable
是唯一的变量ID。 - 对于一个Python生成的密钥原始(如
int
,float
,str
)是它的值。 - 为嵌套
dict
s,list
s,tuple
s,namedtuple
s和attr
s生成的键是叶子键的扁平元组(请参阅参考资料nest.flatten
)。(由于这种扁平化,以与在跟踪过程中使用的嵌套结构不同的嵌套结构调用具体函数将导致TypeError)。 - 对于所有其他Python类型,键都是基于对象的,
id()
因此可以为类的每个实例独立地跟踪方法。
注意:缓存键基于Function
输入参数,因此仅对全局变量和可用变量的更改将不会创建新的跟踪。有关处理Python全局变量和自由变量的建议做法,请参见本节。
控制追溯
当您Function
创建多个跟踪时,追溯将帮助确保TensorFlow为每组输入生成正确的图形。但是,跟踪是一项昂贵的操作!如果您Function
为每