自定义Op编译进Tensorflow并使用Py和C++接口调用

77 篇文章 21 订阅
54 篇文章 0 订阅

本教程使用Bazel把自定义op编译进TensorFlow,并分别用C++和Python调用,使之全局生效;而不是像之前那样通过加载.so文件的方式调用(仅局部生效)。

克隆最新版Tensorflow,包括依赖:

git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git

自定义运算的两个cpp文件my_add.cc和zero_out.cc,放置于tensorflow/core/user_ops/目录下。

my_add.cc计算两个Tensor的逐元素之和,只是把结果的第一个元素设置为0;zero_out.cc输出一个与输入形状相同的Tensor,保留第一个元素,其余元素全部置为0。

//my_add.cc
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

// Registers the "MyAdd" op: two int32 inputs (x, y) and one int32 output (z).
REGISTER_OP("MyAdd")
    .Input("x: int32")
    .Input("y: int32")
    .Output("z: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      // z is reported with x's shape, which is the shape the MyAdd kernel
      // allocates for output 0. Setting output 0 a second time from input 1
      // would silently replace this with y's shape.
      c->set_output(0, c->input(0));
      return Status::OK();
    });


#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class MyAddOp : public OpKernel {
 public:
  explicit MyAddOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);
    auto A = a.flat<int32>();
    auto B = b.flat<int32>();
    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = A.size();

    for (int i = 1; i < N; i++) {
      output_flat(i) = A(i)+B(i);
    }
    output_flat(0) = 0;
  }
};


REGISTER_KERNEL_BUILDER(Name("MyAdd").Device(DEVICE_CPU), MyAddOp);
//zero_out.cc
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"


using namespace tensorflow;

// Op registration for "ZeroOut": one int32 input ("to_zero") and one int32
// output ("zeroed").
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {
      // Output 0 is given the same shape handle as input 0.
      const auto in_shape = ctx->input(0);
      ctx->set_output(0, in_shape);
      return Status::OK();
    });


class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};


REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

之后需要安装Bazel并解决依赖,详见[Tensorflow源码安装教程](https://blog.csdn.net/a446712385/article/details/79149977)。

接下来按照以下命令编译并重新安装Tensorflow:

# Remove the currently installed TensorFlow before installing the rebuilt one.
pip uninstall tensorflow
# Discard previous Bazel outputs, then build the pip-package builder.
bazel clean
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
# Write the wheel into /tmp/tensorflow_pkg and install it. The glob matches
# whatever version/platform tag the build produced.
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
pip install /tmp/tensorflow_pkg/tensorflow-*.whl

测试Python接口并导出graph.pb, 新建一个export.py写入以下内容,如果执行成功并且在export.py同目录下生成graph.pb则证明Python接口测试成功!

#export.py
import tensorflow as tf
import numpy as np

# Build a small graph that applies the custom my_add op to two variables,
# evaluate it, and write the graph definition to ./graph.pb in binary form.
with tf.Session() as sess:
    a = tf.Variable([1,2,3,4,5], name='a')
    b = tf.Variable([5,4,3,2,1], name='b')
    # The node is named "c" so it can be fetched by that name later.
    c = tf.user_ops.my_add(a, b, name="c")

    sess.run(tf.global_variables_initializer())

    # print(...) with a single argument works on both Python 2 and Python 3.
    print(a.eval())  # [1 2 3 4 5]
    print(b.eval())  # [5 4 3 2 1]
    print(c.eval())  # [0 6 6 6 6] -- element 0 is forced to 0 by my_add

    tf.train.write_graph(sess.graph_def, './', 'graph.pb', as_text=False)

执行命令cp -r tensorflow/bazel-genfiles/tensorflow/cc/ops /usr/local/include/

在tensorflow/core/user_ops/models目录下新建一个test.cc文件,内容如下:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "/usr/local/include/ops/user_ops.h"



using namespace tensorflow;
using namespace std;

int main(int argc, char* argv[]) {
  // Initialize a tensorflow session
  Session* session;
  Status status = NewSession(SessionOptions(), &session);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Read in the protobuf graph we exported
  // (The path seems to be relative to the cwd. Keep this in mind
  // when using `bazel run` since the cwd isn't where you call
  // `bazel run` but from inside a temp folder.)
  GraphDef graph_def;
  status = ReadBinaryProto(Env::Default(), "graph.pb", &graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Add the graph to the session
  status = session->Create(graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Setup inputs and outputs:

  // Our graph doesn't require any inputs, since it specifies default values,
  // but we'll change an input to demonstrate.
  Tensor a(DT_INT32, TensorShape({1, 2, 3, 4, 5}));
  auto input_A = a.tensor<int, 5>();
  input_A(0) = 1;
  input_A(1) = 2;
  input_A(2) = 3;
  input_A(3) = 4;
  input_A(4) = 5;



  Tensor b(DT_INT32, TensorShape({1, 2, 3, 4, 5}));
  auto input_B = b.tensor<int, 5>();
  input_B(0) = 1;
  input_B(1) = 2;
  input_B(2) = 3;
  input_B(3) = 4;
  input_B(4) = 5;
  // std::cout<<input_A<<std::endl;
  // std::cout<<input_B<<std::endl;
  // Tensor aa(DT_INT32, TensorShape());
  // aa.scalar<int32>()() = 3;
  //
  // Tensor bb(DT_INT32, TensorShape());
  // bb.scalar<int32>()() = 3;


  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
    { "a", a },
    { "b", b },
  };

  // The session will initialize the outputs
  std::vector<tensorflow::Tensor> outputs;

  // Run the session, evaluating our "c" operation from the graph
  status = session->Run(inputs, {"c"}, {}, &outputs);
  if (!status.ok()) {
    std::cout <<"haha"<<std::endl;
    std::cout << status.ToString() << "\n";
    std::cout <<"hehe"<<std::endl;
    return 1;
  }

  // Grab the first output (we only evaluated one graph node: "c")
  // and convert the node to a scalar representation.
  auto output_c = outputs[0];
  std::cout <<"123"<<std::endl;

  // (There are similar methods for vectors and matrices here:
  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)

  // Print the results
  std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 30>
  //std::cout << output_c() << "\n"; // 30

  // Free any resources used by the session
  session->Close();
  return 0;
}

接下来在tensorflow/core/user_ops/目录下新建一个models目录,然后在tensorflow/core/user_ops/models新建一个BUILD文件,内容如下:

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

# C++ binary target named "test", compiled from test.cc in this package.
cc_binary(
    name = "test",
    srcs = ["test.cc"],
    # TensorFlow targets this binary depends on.
    deps = [
        "//tensorflow/core:tensorflow",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:client_session",
    ]
)

然后执行编译命令:bazel build -c opt --config=monolithic --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" //tensorflow/core/user_ops/models:test

然后把之前tensorflow/tensorflow/core/user_ops/models目录下的graph.pb复制到tensorflow/bazel-bin/tensorflow/core/user_ops/models目录下,执行`./test`,即可输出:

2018-07-23 13:00:35.351963: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
123
Tensor<type: int32 shape: [1,2,3,4,5] values: [[[[0 4 6]]]]...>

参考:C++加载Tensorflow自带Op的Graph: https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值