TensorFlow2.0教程-使用tf.function和AutoGraph提高代码性能

TensorFlow2.0教程-使用tf.function和AutoGraph提高代码性能

原文地址:https://doit-space.blog.csdn.net/article/details/95041177

最全Tensorflow 2.0 入门教程持续更新:https://blog.csdn.net/qq_31456593/article/details/88606284

完整tensorflow2.0教程代码请看 https://github.com/czy36mengfei/tensorflow2_tutorials_chinese (欢迎star)

本教程主要由tensorflow2.0官方教程的个人学习复现笔记整理而来,中文讲解,方便喜欢阅读中文教程的朋友,官方教程:https://www.tensorflow.org

在TensorFlow 2.0中,默认情况下启用了急切执行。 对于用户而言直观且灵活(运行一次性操作更容易,更快),但这可能会牺牲性能和可部署性。

要获得最佳性能并使模型可在任何地方部署,可以优先使用tf.function从程序中构建图。 因为有AutoGraph,可以使用tf.function构建高效性能的Python代码,但仍有一些陷阱需要警惕。

今天我们就来介绍一下tensorflow2.0中的TF fuction和AutoGraph。

下面的辅助程序代码,用于演示可能遇到的各种错误。

import contextlib

# 构建包含上下文管理器的函数,使其可以在with中使用
@contextlib.contextmanager
def assert_raises(error_class):
    try:
        yield
    except error_class as e:
        print('Caught expected exception \n  {}: {}'.format(error_class, e))
    except Exception as e:
        print('Got unexpected exception \n  {}: {}'.format(type(e), e))
    else:
        raise Exception('Expected {} to be raised but no error was raised!'.format(
            error_class))

tf.function

一个tf.function定义就像是一个核心TensorFlow操作:可以急切地执行它; 也可以在静态图中使用它; 且它具有梯度。

# 类似一个tensorflow操作
@tf.function
def add(a, b):
    return a+b

add(tf.ones([2,2]), tf.ones([2,2]))
<tf.Tensor: id=14, shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
# tf.function操作可以计算梯度
@tf.function
def add(a, b):
    return a+b
v = tf.Variable(2.0)
with tf.GradientTape() as tape:
    res = add(v, 1.0)

tape.gradient(res, v) 
<tf.Tensor: id=40, shape=(), dtype=float32, numpy=1.0>
# 可以内嵌调用tf.function
@tf.function
def dense_layer(x, w, b):
    return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

<tf.Tensor: id=67, shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

跟踪和多态

Python的动态类型意味着可以使用各种参数类型调用函数,Python将在每个场景中执行不同的操作。

另一方面,TensorFlow图需要静态dtypes和形状尺寸。tf.function通过在必要时回溯函数来生成正确的图结构来弥补这一差距。大多数使用的tf.function源于这种回归行为。

我们可以使用不同类型的参数调用函数来查看正在发生的事情。

# 函数的多态
@tf.function
def double(a):
    print('追踪变量:',a)
    return a + a

print('结果:',double(tf.constant(1)))
print()
print('结果:',double(tf.constant(1.1)))
print()
print('结果:',double(tf.constant('c')))
print()
追踪变量: Tensor("a:0", shape=(), dtype=int32)
结果: tf.Tensor(2, shape=(), dtype=int32)

追踪变量: Tensor("a:0", shape=(), dtype=float32)
结果: tf.Tensor(2.2, shape=(), dtype=float32)

追踪变量: Tensor("a:0", shape=(), dtype=string)
结果: tf.Tensor(b'cc', shape=(), dtype=string)

控制参数类型:
创建一个新的tf.function。tf.function确保单独的对象不共享追踪。
使用该get_concrete_function方法获取特定追踪
指定input_signature何时调用tf.function以确保仅构建一个功能图。

print('构建许可的追踪')
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("执行追踪函数")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("使用不合法参数")
with assert_raises(tf.errors.InvalidArgumentError):
    double_strings(tf.constant(1))
构建许可的追踪
追踪变量: Tensor("a:0", dtype=string)
执行追踪函数
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
使用不合法参数
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute __inference_double_98 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_98]
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
    print("Tracing with", x)
    return tf.where(tf.equal(x % 2, 0), x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# 只能输入1维向量
with assert_raises(ValueError):
    next_collatz(tf.constant([[1, 2], [3, 4]]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>: Python inputs incompatible with input_signature: inputs ((<tf.Tensor: id=125, shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4]], dtype=int32)>,)), input_signature ((TensorSpec(shape=(None,), dtype=tf.int32, name=None),))

什么时候回溯?

多态tf.function通过跟踪生成具体函数的缓存。缓存键实际上是从函数args和kwargs生成的键的元组。为tf.Tensor参数生成的关键是其形状和类型。为Python原语生成的密钥是它的值。对于所有其他Python类型,键都基于对象,id()以便为每个类的实例独立跟踪方法。将来,TensorFlow可以为Python对象添加更复杂的缓存,可以安全地转换为张量。

使用Python参数还是Tensor参数?

通常,Python的参数被用来控制超参数和图的结构-例如,num_layers=10或training=True或nonlinearity=‘relu’。因此,如果Python参数发生变化,那么必须回溯图。

但是,Python参数可能不会用于控制图构造。在这些情况下,Python值的变化可能会触发不必要的回溯。举例来说,这个训练循环,AutoGraph将动态展开。尽管存在多条迹线,但生成的图实际上是相同的,因此这有点低效。

def train_one_step():
    pass

@tf.function
def train(num_steps):
    print(&
  • 5
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值