如何为 Cloud TPU 编写自定义估算器模型

文 / Google Cloud Platform 技术主管 Lak Lakshmanan (@lak_gcp)

来源 | TensorFlow 公众号

张量处理单元 (TPU) 可加速处理 Google 内各种机器学习工作负载,并可供 Google Cloud 客户使用。您可以在 Cloud TPU 参考模型存储区找到启用 TPU 的顶尖图像模型版本,例如 ResNet 和 AmoebaNet;您还可以使用强大的 Tensor2Tensor 库,在 TPU 上执行文本摘要和问答任务。这些教程会为您分步介绍如何使用当下很多最热门的 Cloud TPU 参考模型。

注:存储区链接
https://github.com/tensorflow/tpu
教程链接
https://cloud.google.com/tpu/docs/tutorials

在这里插入图片描述
但如果您拥有自定义 TensorFlow 模型,又该如何做呢?在本文中,我会逐步介绍编写自定义估算器以便在 Cloud TPU 上运行的全过程。在此过程中,我会指出需要注意的地方和建议采用的最佳实践。您可以在 GitHub 上找到此解决方案的完整代码;本文仅列出相关代码片段。

注:解决方案的完整代码链接
https://github.com/GoogleCloudPlatform/training-data-analyst/tree/master/courses/machine_learning/deepdive/08_image/flowersmodeltpu

自定义 TensorFlow 估算器包含以模型函数传递的基类估算器:

  def train_and_evaluate(output_dir, nsteps):    
         estimator = tf.estimator.Estimator(    
                        model_fn = model_fn,    
                        model_dir = output_dir) 

模型函数会接收特征、标签和模式,并返回 EstimatorSpec。例如,图像分类问题的模型函数可能包含

def model_fn(features, labels, mode):
    # write the model to compute predictions, loss, etc. from the model

    return tf.estimator.EstimatorSpec(
                mode=mode,    
                predictions={"probabilities": probabilities, 
                                     "classid": class_int, "class": class_str},
                loss=loss,
                train_op=train_op, 
              eval_metric_ops=evalmetrics,
              export_outputs={'classes': tf.estimator.export.PredictOutput(
                        {"probabilities": probabilities, "classid": class_int, 
                         "class": class_str})}
        )

TensorFlow 中的 tf.contrib.tpu 包提供了包装器类,可助您以适当方式编写代码,以便在 CPU、GPU 和 Cloud TPU 上运行代码。下面我们就来看看如何以这种不受加速器限制的方式编写自定义估计器。

1.将输入数据转换为 TF 记录

Cloud TPU 的速度非常快,一不小心,您的训练就会变成以读写(“馈入” 和 “馈出”)数据和存储检查点为主。让 TPU 等待输入/输出会造成浪费,所以我们会做几件事以充分利用 TPU 用于计算的时间。

首先是避免在估算器的输入函数中进行数据解析和整理,而是预先将数据转换为 TF 记录。与单个图像文件相比,批量处理 TF 记录更为简单,因为记录本身包含标签,如此可以减少系统必须读取的小型文件数量。我使用 Apache Beam 进行这种转换。您可以在官方 TPU 存储区找到读取 JPEG 和编写 TF 记录的脚本。您可以在 Cloud Dataflow 上大规模地执行 Apache Beam 程序,但如果您的数据源目前不在 Google Cloud 上,则只能在大型 VM 上本地执行此程序(请务必用 pip 安装 apache-beam)。

注:JPEG 和编写 TF 记录链接
https://github.com/tensorflow/tpu/blob/master/tools/datasets/jpeg_to_tf_record.py

TF 记录是字典。对于图像分类,上述管道编写的两个条目很重要,分别是:“image/class/label”(采用 int64)和 “image/encoded”(由 JPEG 文件内容组成)。

2.编写输入函数以读取 TF 记录

与任何估算器一样,您需要编写输入函数,以读取这些 TF 记录。使用 Dataset API 可极大地简化此任务,但还需注意几个问题。在讲解过程中,我会指出这些问题。

以下是我的输入函数:

def make_input_fn(pattern, mode, num_cores=8, transpose_input=False):
    def _set_shapes(batch_size, images, labels):
        """Statically set the batch_size dimension."""
            i
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值