『写在前面』
以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程。
基本的流程
1. 定义Op接口
// Step 1: declare the op's interface — its name plus the dtype of each
// input and output. Registration happens at static-init time via this macro.
#include "tensorflow/core/framework/op.h"
// "Custom" consumes one int32 tensor and produces one int32 tensor.
REGISTER_OP("Custom")
.Input("custom_input: int32")
.Output("custom_output: int32");
2. 为Op实现Compute操作(CPU)或实现kernel(GPU)
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class CustomOp : public OpKernel{
public:
explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// 获取输入 tensor.
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat();
// 创建一个输出 tensor.
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat();
//进行具体的运算,操作input和output
//……
}
};
3. 将实现的kernel注册到TensorFlow系统中
REGISTER_KERNEL_BUILDER(Name("Custom").Device(DEVICE_CPU), CustomOp);
CTCBeamSearchDecoder自定义
该Op对应TensorFlow中的源码部分
Op接口的定义:
tensorflow-master/tensorflow/core/ops/ctc_ops.cc
CTCBeamSearchDecoder本身的定义:
tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc
Op-Class的封装与Op注册:
tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc
基于源码修改的Op
// NOTE(review): the three stdlib header names below were stripped by HTML
// extraction in the original post (bare "#include" lines). Restored to the
// headers this kind of kernel code conventionally needs — verify against
// the original source before building.
#include <algorithm>
#include <limits>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
namespace tf = tensorflow;
using tf::shape_inference::DimensionHandle;
using tf::shape_inference::InferenceContext;
using tf::shape_inference::ShapeHandle;
using namespace tensorflow;
REGISTER_OP("CTCBeamSearchDecoderWithParam")
.Input("inputs: float")
.Input("sequence_length: int32")
.Attr("beam_width: int >= 1")
.Attr("top_paths: int >= 1")
.Attr("merge_repeated: bool = true")
//新添加了两个参数
.Attr("label_selection_size: int >= 0 = 0")
.Attr("label_selection_margin: float")
.Output("decoded_indices: top_paths * int64")
.Output("decoded_values: top_paths * int64")
.Output("decoded_shape: top_paths * int64")
.Output("log_probability: float")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle sequence_length;
TF_RETURN_IF_ERROR(c->W