tf.function功能提升绘图性能

在TensorFlow 2中,tf.function可以将Python代码转换为图形以提高性能和可移植性。通过追踪和自动图转换,它可以管理Python副作用和控制流程。然而,依赖Python副作用、变量创建和全局变量的变动可能导致问题。建议在调试时使用急切执行,避免在循环中处理Python数据,使用tf.data处理数据,并避免在tf.function内部改变变量状态。
摘要由CSDN通过智能技术生成

在TensorFlow 2中,默认情况下将打开急切执行功能。用户界面直观而灵活(运行一次性操作要容易得多且更快),但这可能会牺牲性能和可部署性。

您可以使用tf.function程序制作图形。它是一种转换工具,可以从您的Python代码中创建与Python无关的数据流图。这将帮助您创建高性能和可移植的模型,并且需要使用SavedModel

本指南将帮助您概念化如何tf.function在引擎盖下工作,以便您可以有效地使用它。

主要要点和建议是:

  • 在紧急模式下进行调试,然后使用进行修饰@tf.function
  • 不要依赖于Python的副作用,例如对象突变或列表追加。
  • tf.function与TensorFlow ops配合使用效果最佳;NumPy和Python调用将转换为常量。

设置

 
import  tensorflow  as  tf

定义一个辅助函数来演示您可能遇到的错误类型:

 
import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

基本

用法

一个Function定义(例如,通过应用@tf.function装饰)就像是一个核心TensorFlow操作:您可以急切地执行它; 您可以计算渐变;等等。

 
@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
 
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
 
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
 
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

您可以Function在其他Functions中使用。

 
@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
 
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Functions可能比渴望的代码快,尤其是对于具有许多小操作的图形。但是对于具有一些昂贵操作(例如卷积)的图,您可能看不到太多的加速。

 
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
 
Eager conv: 0.003693543000053978
Function conv: 0.004675635000012335
Note how there's not much difference in performance for convolutions

追踪

本节介绍了幕后Function工作原理,包括将来可能会更改的实现细节。但是,一旦您了解了跟踪的原因和时间,tf.function有效地使用它就容易多了!

什么是“追踪”?

AFunctionTensorFlow Graph中运行程序。但是,atf.Graph不能代表您在急切的TensorFlow程序中编写的所有内容。例如,Python支持多态,但是tf.Graph要求其输入具有指定的数据类型和维。或者,您可以执行其他任务,例如读取命令行参数,引发错误或使用更复杂的Python对象。这些东西都不能在Windows中运行tf.Graph

Function 通过将您的代码分为两个阶段来弥合这种差距:

1)在第一阶段,称为“跟踪”,Function创建一个新的tf.Graph。Python代码正常运行,但是所有TensorFlow操作(如添加两个Tensor)都被延迟:它们被捕获,tf.Graph并且不运行。

2)在第二阶段中,tf.Graph运行包含第一阶段中延迟的所有内容的a 。此阶段比跟踪阶段快得多。

根据其输入,Function调用它时将不会始终运行第一阶段。请参阅下面的“跟踪规则”以更好地了解它是如何进行确定的。跳过第一阶段,仅执行第二阶段,即可为您带来TensorFlow的高性能。

Function确定要跟踪时,跟踪阶段紧随其后的是第二阶段,因此调用Function两者都可以创建并运行tf.Graph。稍后,您将看到如何仅使用来运行跟踪阶段get_concrete_function

当我们将不同类型的参数传递给时Function,两个阶段都将运行:

 
@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
 
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

请注意,如果您重复调用Function具有相同参数类型的a,TensorFlow将跳过跟踪阶段并重用先前跟踪的图,因为生成的图将是相同的。

 
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
 
tf.Tensor(b'bb', shape=(), dtype=string)

您可以pretty_printed_concrete_signatures()用来查看所有可用的跟踪:

 
print(double.pretty_printed_concrete_signatures())
 
double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

到目前为止,您已经看到tf.function在TensorFlow的图跟踪逻辑上创建了一个缓存的动态调度层。要更具体地讲术语,请执行以下操作:

  • Atf.Graph是TensorFlow计算的原始,与语言无关的可移植表示形式。
  • AConcreteFunction包装一个tf.Graph
  • AFunction管理ConcreteFunctions的缓存,并为您的输入选择正确的一个。
  • tf.function包装一个Python函数,返回一个Function对象。
  • 跟踪创建一个tf.Graph并将其包装在中ConcreteFunction,也称为跟踪。

追踪规则

AFunction确定是否重用ConcreteFunction通过从输入的args和kwargs计算缓存键来跟踪的跟踪。甲缓存键是标识密钥ConcreteFunction基于输入指定参数和所述的kwargsFunction呼叫,根据下面的规则(其可以改变):

  • 为a生成的密钥tf.Tensor是其形状和dtype。
  • 为a生成的密钥tf.Variable是唯一的变量ID。
  • 对于一个Python生成的密钥原始(如intfloatstr)是它的值。
  • 为嵌套dicts,lists,tuples,namedtuples和attrs生成的键是叶子键的扁平元组(请参阅参考资料nest.flatten)。(由于这种扁平化,以与在跟踪过程中使用的嵌套结构不同的嵌套结构调用具体函数将导致TypeError)。
  • 对于所有其他Python类型,键都是基于对象的,id()因此可以为类的每个实例独立地跟踪方法。

注意:缓存键基于Function输入参数,因此仅对全局变量可用变量的更改将不会创建新的跟踪。有关处理Python全局变量和自由变量的建议做法,请参见本节

控制追溯

当您Function创建多个跟踪时,追溯将帮助确保TensorFlow为每组输入生成正确的图形。但是,跟踪是一项昂贵的操作!如果您Function为每

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

歇歇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值