使用 Eager Execution 编码并运行图表:以通过 RevNet 优化代码为例

本文通过可逆残差网络(RevNet)的实现,展示如何利用 TensorFlow 的 Eager Execution 从原型设计快速过渡到使用 tf.keras 和 tf.estimator 在 Cloud TPU 上进行高效训练。探讨了 Eager Execution 的优势,如简化内存管理,以及如何利用 tf.train.Checkpoint、tf.data.Dataset 和 TPUEstimator 进行模型存储、输入管道构建和分布式训练。
摘要由CSDN通过智能技术生成

文 / 软件工程实习生 Xuechen Li

来源: TensorFlow 公众号

Eager Execution 可简化 TensorFlow 中的模型构建体验,而 Graph Execution 可提供优化,以加快模型运行速度及提高存储效率。本篇博文展示了如何编写 TensorFlow 代码,以便将借助 tf.keras API 并使用 Eager Execution 构建的模型转换为图表,最终借助 tf.estimator API 的支持,在 Cloud TPU 上部署此模型。
注:tf.keras 链接
https://www.tensorflow.org/guide/keras
tf.estimator 链接
https://www.tensorflow.org/guide/estimators

我们使用可逆残差网络(RevNet、Gomez 等)作为示例。接下来的部分假设读者对卷积神经网络和 TensorFlow 有基本了解。您可以在此处找到本文的完整代码(为确保代码在所有设置中正常运行,强烈建议您使用 tf-nightly 或 tf-nightly-gpu)。

RevNet

RevNet 与残差网络(ResNet、He 等)类似,只不过他们是可逆的,在给定输出的情况下可重建中间计算。此技术的好处之一是我们可以通过重建激活来节省内存,而不是在训练期间将其全部存储在内存中(回想一下,由于链式法则有此要求,因此我们需要中间结果来计算有关输入的梯度)。相比传统架构上的一般反向传播,这使我们可以适应较大的批次大小,并可训练更具深度的模型。具体来说,此技术的实现方式是通过使用一组巧妙构建的方程来定义网络:

其中顶部和底部方程组分别定义正演计算和其反演计算。这里的 x1 和 x2 是输入(从整体输入 x 中拆分出来),y1 和 y2 是输出,F 和 G 是 ConvNet。这使我们能够在反向传播期间精准重建激活,如此一来,在训练期间便无需再存储这些数据。

使用 tf.keras.Model 定义正向和反向传递

假设我们使用 “ResidualInner” 类来实例化函数 F 和 G,我们可以通过子类化 tf.keras.Model 来定义可逆代码块,并通过替换上面的方程中所示的 call 方法来定义正向传递:

1    class Residual(tf.keras.Model):    
2        def __init__(self, filters):    
3            super(Residual, self).__init__()    
4            self.f = ResidualInner(filters=filters, strides=(1, 1))
5            self.g = ResidualInner(filters=filters, strides=(1, 1))
6
7        def call(self, x, training=True):    
8            x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
9            f_x2 = self.f(x2, training=training)    
10            y1 = f_x2 + x1    
11            g_y1 = self.g(y1, training=training)    
12            y2 = g_y1 + x2    
13            return tf.concat([y1, y2], axis=self.axis) 

这里的 training 参数用于确定批标准化的状态。启用 Eager Execution 后,批标准化的运行平均值会在 training=True 时自动更新。执行等效图时,我们需要使用 get_updates_for 方法手动获取批标准化更新。

要构建节省内存的反向传递,我们需要使用 tf.GradientTape 作为上下文管理器来跟踪梯度(仅在有需要时):
注:tf.GradientTape 链接
https://www.tensorflow.org/api_docs/pyt

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值