目前正在学习tensorflow自定义OP,刚学会如何添加和添加简单的op代码。
预备技能
对 C++ 有一定了解.
已经下载 TensorFlow 源代码并有能力编译它.
第一步:找一个文件夹存放你要编译的文件my_add.cc并调用 REGISTER_OP 宏来定义 Op 的接口.
该OP接受两个int32 类型tensor 作为 输入,并将这两个tensor进行求和并将第一位置0输出出来。
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
REGISTER_OP("MyAdd")
.Input("x: int32")
.Input("y: int32")
.Output("z: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output(0, c->input(1));
return Status::OK();
});
class MyAddOp : public OpKernel {
public:
explicit MyAddOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
auto A = a.flat<int32>();
auto B = b.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = A.size();
for (int i = 1; i < N; i++) {
output_flat(i) = A(i)+B(i);
output_flat(0) = 0;
}
}
};
REGISTER_KERNEL_BUILDER(Name("MyAdd").Device(DEVICE_CPU), MyAddOp);
第二步:cmake编译
在你当前存放my_add.cc目录下执行下面命令编译.so文件:
TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
g++ -std=c++11 -shared my_add.cc -o my_add.so -fPIC -I$TF_INC -I$TF_INC/external/nsync/public -L$TF_LIB -ltensorflow_framework -O2
如果你使用的是python3.5:
TF_INC=$(python3.5 -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_LIB=$(python3.5 -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
g++ -std=c++11 -shared my_add.cc -o my_add.so -fPIC -I$TF_INC -I$TF_INC/external/nsync/public -L$TF_LIB -ltensorflow_framework -O2
我的gcc版本是5.4,可以正常编译,编译完成后会在当前文件夹下产生my_add.so文件,记住当前文件存放路径。
如果提示无法找到op.h文件,请把tensorflow源代码里面的op.h复制到你的tensorflow安装目录当中。
第三步:使用该OP
import tensorflow as tf
so_file = '你的文件路径/my_add.so'
if __name__ == "__main__":
#tf.test.main()
my_add_module = tf.load_op_library(so_file)
out = my_add_module.my_add([5, 4, 3, 2, 1],[1, 2, 3, 4, 5])
sess = tf.Session()
result = sess.run(out)
print(result)
输出[0, 6, 6, 6, 6],成功!
参考:
https://blog.csdn.net/xiangxianghehe/article/details/81002227
https://cloud.tencent.com/developer/section/1475696