深度框架Tensorflow系列之(二)OP开发

深度框架Tensorflow系列之(二)OP开发

上篇文章《深度框架Tensorflow系列之(一)开发环境部署》已经介绍了Tensorflow的安装部署,接下来是时候深入理解下Tensorflow的相关的技术点了,不过在这之前笔者先写了一篇关于Tensorflow OP编写方面的文字给大家预预热,OP大家可以理解为算子,我们在写模型代码的时候使用的注入add等大家都可以理解为算子(或者理解为函数就好)。

1 背景

如果你想要创建一个在TensorFlow 库中不存在的操作,我们建议你先从 Python 入手,即写一个已有 Python 操作或函数的复合操作。 如果这样不可行,你可以定制一个 C++ 操作。下面是你可能需要这样做的一些理由:

  1. 将你的操作表示成现有操作的组合不太容易或不可能。
  2. 已有基本操作的组合操作效率不高。
  3. 你想手工实现一些基本操作的组合,因为未来的编译器做这种融合可能会比较困难。

例如,假设您想要实现诸如“中值池化”之类的功能,与“MaxPool”算子类似,但需要计算滑动窗口期间的中值而不是最大值。可以使用运算组合来实现这一目的(例如,使用 ExtractImagePatches 和 TopK),但在性能或内存效率方面可能不如原生运算那样出色,对于原生运算,您可以利用单个融合运算实现更巧妙的过程。和往常一样,通常有必要首先尝试使用算子组合来表示您想要的运算,只有在这被证实难以实现或效率低下时,才选择添加新运算。

2 OP开发步骤

为了实现自定义OP的开发,你可能需要如下的4个步骤,其中1和2是必选项,3和4是可选项,大家可以根据实际情况进行设计与开发:

  1. OP注册:在 C++ 文件中注册这个新操作。操作的注册为此操作的功能定义了一个接口(规范)。比如,操作的注册定义了此操作的名称和它的输入输出。它还定义了 shape 函数,用于获取张量的形状。
  2. OP实现:使用 C++ 实现运算。运算的实现称为内核,它是您在第 1 步中注册的规范的具体实现。可以有多个内核用于不同的输入/输出类型或架构(例如,CPU、GPU)。
  3. Python包装:创建一个 Python 包装器(可选)。这个包装器是用于在 Python 中创建操作的公共 API。操作的注册可以产生一个默认的包装器,它可以直接使用,或添加。
  4. 梯度计算:为操作编写一个函数来计算梯度(可选)。
  5. 测试运算:为方便起见,我们通常在 Python 中进行测试,但您也可以在 C++ 中测试运算。如果您要定义梯度,可以使用 Python tf.test.compute_gradient_error 验证梯度。要了解如何测试 ReLu 之类的算子及其梯度的前向函数,请参阅 relu_op_test.py

OP开发实例

下面通过一个简单的示例来介绍下,如何进行一个C++ OP的开发,代码相对来说比较简单,主要涉及上面OP开发步骤的1和2,详情如下:

  1. 程序路径:tensorflow/tensorflow/core/user_ops,大家还记得上篇文章编译过Tensorflow1.15吧,我直接将代码放在了编译好的Tensorflow路径下。
  2. 程序代码

code/ml/tensorflow-op/simple_user_op at master · dubaokun/code · GitHub

  1. 代码示例

  2

  3 Licensed under the Apache License, Version 2.0 (the "License");

  4 you may not use this file except in compliance with the License.

  5 You may obtain a copy of the License at

  6

  7     http://www.apache.org/licenses/LICENSE-2.0

  8

  9 Unless required by applicable law or agreed to in writing, software

 10 distributed under the License is distributed on an "AS IS" BASIS,

 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 12 See the License for the specific language governing permissions and

 13 limitations under the License.

 14 ===========================================================================*/

 15

 16 // An example Op.

 17

 18 #include "tensorflow/core/framework/common_shape_fns.h"

 19 #include "tensorflow/core/framework/op.h"

 20 #include "tensorflow/core/framework/op_kernel.h"

 21

 22 // OP Register

 23 REGISTER_OP("FactTest")

 24     .Output("fact: string")

 25     .SetShapeFn(tensorflow::shape_inference::UnknownShape);

 26

 27 // OP Operation

 28 class FactTestOp : public tensorflow::OpKernel {

 29  public:

 30   explicit FactTestOp(tensorflow::OpKernelConstruction* context)

 31       : OpKernel(context) {}

 32

 33   void Compute(tensorflow::OpKernelContext* context) override {

 34     // Output a scalar string.

 35     tensorflow::Tensor* output_tensor = nullptr;

 36     OP_REQUIRES_OK(context, context->allocate_output(

 37                                 0, tensorflow::TensorShape(), &output_tensor));

 38     using tensorflow::string;

 39     auto output = output_tensor->template scalar<tensorflow::tstring>();

 40

 41     output() = "0! == 1";

 42   }

 43 };

 44

 45 // OP Device Bind

 46 REGISTER_KERNEL_BUILDER(Name("FactTest").Device(tensorflow::DEVICE_CPU), FactTestOp);

于是,我们注册了一个名为 FactTest的操作,并且在里面进行了简单的实现。然后bind了OP在CPU上运行。

直接运行run.sh就可以进行编译了,然后可以在bazel-bin目录找打fact_test.so。ResourceMgr来跟踪操作的状态。

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值