tensorflow创建动态库文件来自定义op

目前正在学习tensorflow自定义OP,刚学会如何添加和添加简单的op代码。

预备技能
对 C++ 有一定了解.
已经下载 TensorFlow 源代码并有能力编译它.

第一步:找一个文件夹存放你要编译的文件my_add.cc并调用 REGISTER_OP 宏来定义 Op 的接口.
该OP接受两个int32 类型tensor 作为 输入,并将这两个tensor进行求和并将第一位置0输出出来。

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

REGISTER_OP("MyAdd")
    .Input("x: int32")
    .Input("y: int32")
    .Output("z: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      c->set_output(0, c->input(1));
      return Status::OK();
    });

class MyAddOp : public OpKernel {
 public:
  explicit MyAddOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);
    auto A = a.flat<int32>();
    auto B = b.flat<int32>();
    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = A.size();

    for (int i = 1; i < N; i++) {
      output_flat(i) = A(i)+B(i);
      output_flat(0) = 0;
    }
  }
};
REGISTER_KERNEL_BUILDER(Name("MyAdd").Device(DEVICE_CPU), MyAddOp);

第二步:cmake编译
在你当前存放my_add.cc目录下执行下面命令编译.so文件:

TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
g++ -std=c++11 -shared my_add.cc -o my_add.so -fPIC -I$TF_INC -I$TF_INC/external/nsync/public -L$TF_LIB -ltensorflow_framework -O2

如果你使用的是python3.5:

TF_INC=$(python3.5 -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_LIB=$(python3.5 -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
g++ -std=c++11 -shared my_add.cc -o my_add.so -fPIC -I$TF_INC -I$TF_INC/external/nsync/public -L$TF_LIB -ltensorflow_framework -O2

我的gcc版本是5.4,可以正常编译,编译完成后会在当前文件夹下产生my_add.so文件,记住当前文件存放路径。
如果提示无法找到op.h文件,请把tensorflow源代码里面的op.h复制到你的tensorflow安装目录当中。
第三步:使用该OP

import tensorflow as tf

so_file = '你的文件路径/my_add.so'

if __name__ == "__main__":
  #tf.test.main()
  my_add_module = tf.load_op_library(so_file)
  out = my_add_module.my_add([5, 4, 3, 2, 1],[1, 2, 3, 4, 5])
  sess = tf.Session()
  result = sess.run(out)
  print(result)

输出[0, 6, 6, 6, 6],成功!

参考:
https://blog.csdn.net/xiangxianghehe/article/details/81002227
https://cloud.tencent.com/developer/section/1475696

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值