- 用途:用python写TF项目时,可能要把数据流向C++文件处理后再给python,这就要自定义OP。也就是写个C++动态链接库,用自定义的OP让python和C++动态链接库连接起来。
- TF的版本要足够高,因为低版本的TF没有libtensorflow_framework.so。如tf1.2.1就没有,而tf1.5.0有。
- 出现类似undefined symbol: _ZN10tensorflow7strings6StrCatERKNS0_8AlphaNumES3_的错误,你可能在编译C++动态链接库时要加上或去掉add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)。
- 下面给出三类文件:
- 生成C++动态链接库的文件:zero_out1.cc、zero_out2.cc、zero_out3.cc、zero_out4.cc。这四个文件内容差不多。
- 编译C++动态链接库的配置文件:CMakeLists.txt。
- 调用自定义OP的两个python例子:1zero_out_op_test.py、4zero_out_op_test.py。这两个文件内容差不多。
a.生成C++动态链接库的文件,前3个都可以被1zero_out_op_test.py调用,主要区别是是否使用泛型。
zero_out1.cc
#include "tensorflow/core/framework/op.h"
// Declare the op's interface: its name, attrs, inputs and outputs.
REGISTER_OP("ZeroOut") //specify the name of the op
.Attr("preserve_index: int=0") //int attribute with a default value of 0
.Input("to_zero: int32") //name and dtype of the single input
.Output("zeroed: int32") //name and dtype of the single output
// Optional shape-inference function (disabled here): output shape == input shape.
/*.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})*/;
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
// CPU kernel: zeroes every element of the int32 input except the one at
// index preserve_index_ (an op attribute, default 0), which is passed through.
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the attribute value and store it in preserve_index_.
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // The attribute must be a valid (non-negative) index.
    OP_REQUIRES(context,
                preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ", preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // Get the input tensor from the context.
    const Tensor& input_tensor = context->input(0);
    /* Optional shape check (disabled):
     * IsScalar: shape.dims() == 0; IsVector: shape.dims() == 1;
     * IsMatrix: shape.dims() == 2; etc. */
    /*OP_REQUIRES(context,
        TensorShapeUtils::IsVector(input_tensor.shape()),
        errors::InvalidArgument("ZeroOut expects a 1-D vector."));*/
    auto input = input_tensor.flat<int32>();
    // The attribute must address an existing element of the flattened input.
    OP_REQUIRES(context,
                preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));
    // Create the output tensor, allocating memory via context->allocate_output().
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
    auto output = output_tensor->template flat<int32>();
    // Zero every element, then restore the one requested by the attribute.
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output(i) = 0;
    }
    // BUG FIX: the original wrote output(preserve_index_ + 1), which both
    // preserved the wrong element and could index one past the end when
    // preserve_index_ == N - 1 (the bounds check above only guarantees
    // preserve_index_ itself is in range). Preserve exactly the requested
    // element, matching the official "Adding an Op" guide.
    output(preserve_index_) = input(preserve_index_);
  }
 private:
  int preserve_index_;  // index of the element to keep (from the op attribute)
};
// Register the op (ZeroOut) and its implementation (ZeroOutOp) with TensorFlow.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
zero_out2.cc
#include "tensorflow/core/framework/op.h"
// ZeroOut with a type attr: "T" may be float or int32 and defaults to int32,
// so int32 inputs work without callers specifying T explicitly.
REGISTER_OP("ZeroOut")
.Attr("T: {float, int32} = DT_INT32")
.Input("to_zero: T")
.Output("zeroed: T");
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutInt32Op : public OpKernel {
public:
explicit ZeroOutInt32Op(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
Tensor* output = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output));
auto output_flat = output->template flat<int32>();
const int N = input.size();
for (int i = 1; i < N; i++)
output_flat(i) = 0;
if (N > 0) output_flat(0) = input(0);
}
};
// float kernel for ZeroOut: keep input(0), zero everything else.
class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction * context) : OpKernel(context) {}
  void Compute(OpKernelContext * context) override {
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();
    // Allocate the output with the same shape as the input.
    Tensor * output = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();
    const int N = input.size();
    // Start at 1: element 0 is assigned below, so the original i = 0 loop
    // performed a redundant write it then overwrote. Starting at 1 is also
    // consistent with the int32 kernel above. Result is unchanged.
    for (int i = 1; i < N; i++)
      output_flat(i) = 0;
    if (N > 0) output_flat(0) = input(0);
  }
};
// Dispatch on the "T" attr: one kernel registration per supported dtype.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<int32>("T"), ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<float>("T"), ZeroOutFloatOp);
zero_out3.cc
#include "tensorflow/core/framework/op.h"
// Generic ZeroOut: the attr "T" selects the element dtype. No default is
// given, so T is inferred from the input tensor's dtype.
REGISTER_OP("ZeroOut")
.Attr("T: {float, double, int32}")
.Input("to_zero: T")
.Output("zeroed: T");
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
template <typename T>
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<T>();
Tensor* output = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output));
auto output_flat = output->template flat<T>();
const int N = input.size();
for (int i = 1; i < N; i++)
output_flat(i) = 0;
if (N > 0) output_flat(0) = input(0);
}
};
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<int32>("T"), ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<float>("T"), ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<double>("T"), ZeroOutOp<double>);
zero_out4.cc
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
// ZeroOut variant with TWO outputs: the zeroed tensor plus an extra
// "indice" tensor (both int32) to demonstrate multi-output ops.
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.Output("indice: int32");
// Kernel with two outputs:
//   output 0: the input with every element after index 0 set to zero;
//   output 1: a fixed 2x3x4 int32 tensor filled with 1..24, demonstrating
//             how to build a shape explicitly and allocate a second output.
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();
    Tensor* output_tensor = NULL;
    Tensor* output_tensor_indice = NULL;
    TensorShape indice_shape;
    // Demo dimensions for the second output (the unused variable `d` from the
    // original has been removed).
    const int d0 = 2, d1 = 3, d2 = 4;
    const int n = 3;  // number of dimensions taken from dims[]
    int dims[] = {d0, d1, d2};
    // Build a TensorShape from dims[0..n-1]. MakeShape returns a Status that
    // the original silently ignored; check it so a malformed shape fails the
    // op instead of being dropped.
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(dims, n, &indice_shape));
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
    OP_REQUIRES_OK(context, context->allocate_output(1, indice_shape, &output_tensor_indice));
    auto output_flat = output_tensor->flat<int32>();
    auto indice_flat = output_tensor_indice->flat<int32>();
    // Fill the second output with the values 1..d0*d1*d2.
    for (int i = 0; i < d0 * d1 * d2; i++)
      indice_flat(i) = 1 + i;
    // First output: keep element 0, zero the rest.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }
    if (N > 0) output_flat(0) = input(0);
  }
};
// Register the op and its CPU implementation with TensorFlow.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
b.编译C++动态链接库的配置文件
# Build each zero_outN.cc into its own shared library that Python can load
# with tf.load_op_library().
cmake_minimum_required(VERSION 2.8)
project(zero_out)
# TensorFlow 1.5 binaries require C++11.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
# Toggle this if linking fails with undefined-symbol ABI errors against
# libtensorflow_framework.so (old-ABI TensorFlow builds need it).
#add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
#add_definitions(-D_MWAITXINTRIN_H_INCLUDED)
# Location of the TensorFlow installation; this example points into an
# Anaconda environment's site-packages.
set(TF_PATH /home/<yourname>/software/anaconda2/envs/tf150_27/lib/python2.7/site-packages/tensorflow)
# TensorFlow headers plus the bundled Eigen and nsync headers.
include_directories(
${TF_PATH}/include
${TF_PATH}/include/external/eigen_archive
${TF_PATH}/include/external/nsync/public
)
# libtensorflow_framework.so lives at the top of the TF install.
link_directories(
${TF_PATH}
)
# One shared library per op source file: libzero_out1.so ... libzero_out4.so.
add_library(${PROJECT_NAME}1 SHARED zero_out1.cc)
target_link_libraries(${PROJECT_NAME}1 tensorflow_framework)
add_library(${PROJECT_NAME}2 SHARED zero_out2.cc)
target_link_libraries(${PROJECT_NAME}2 tensorflow_framework)
add_library(${PROJECT_NAME}3 SHARED zero_out3.cc)
target_link_libraries(${PROJECT_NAME}3 tensorflow_framework)
add_library(${PROJECT_NAME}4 SHARED zero_out4.cc)
target_link_libraries(${PROJECT_NAME}4 tensorflow_framework)
c.调用自定义OP的两个python例子
1zero_out_op_test.py
import tensorflow as tf
class ZeroOutTest(tf.test.TestCase):
def testZeroOut(self):
zero_out_module = tf.load_op_library('build/libzero_out1.so')
with self.test_session():
result = zero_out_module.zero_out([5, 4, 3, 2, 1])
#self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
print "-------Result: "
print result.eval()
if __name__ == "__main__":
tf.test.main()
4zero_out_op_test.py
import tensorflow as tf
class ZeroOutTest(tf.test.TestCase):
def testZeroOut(self):
zero_out_module = tf.load_op_library('build/libzero_out4.so')
with self.test_session():
result = zero_out_module.zero_out([5, 4, 3, 2, 1])
#self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
print "-------Result[0]: "
print result[0].eval()
print "-------Result[1]: "
print result[1].eval()
if __name__ == "__main__":
tf.test.main()