tensorflow自定义op_TensorFlow实现自定义Op方式

『写在前面』

以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程。

基本的流程

1. 定义Op接口

#include "tensorflow/core/framework/op.h"

REGISTER_OP("Custom")

.Input("custom_input: int32")

.Output("custom_output: int32");

2. 为Op实现Compute操作(CPU)或实现kernel(GPU)

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class CustomOp : public OpKernel{

public:

explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}

void Compute(OpKernelContext* context) override {

// 获取输入 tensor.

const Tensor& input_tensor = context->input(0);

auto input = input_tensor.flat();

// 创建一个输出 tensor.

Tensor* output_tensor = NULL;

OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),

&output_tensor));

auto output = output_tensor->template flat();

//进行具体的运算,操作input和output

//……

}

};

3. 将实现的kernel注册到TensorFlow系统中

REGISTER_KERNEL_BUILDER(Name("Custom").Device(DEVICE_CPU), CustomOp);

CTCBeamSearchDecoder自定义

该Op对应TensorFlow中的源码部分

Op接口的定义:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

CTCBeamSearchDecoder本身的定义:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc

Op-Class的封装与Op注册:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

基于源码修改的Op

#include

#include

#include

#include "tensorflow/core/util/ctc/ctc_beam_search.h"

#include "tensorflow/core/framework/op.h"

#include "tensorflow/core/framework/op_kernel.h"

#include "tensorflow/core/framework/shape_inference.h"

#include "tensorflow/core/kernels/bounds_check.h"

namespace tf = tensorflow;

using tf::shape_inference::DimensionHandle;

using tf::shape_inference::InferenceContext;

using tf::shape_inference::ShapeHandle;

using namespace tensorflow;

REGISTER_OP("CTCBeamSearchDecoderWithParam")

.Input("inputs: float")

.Input("sequence_length: int32")

.Attr("beam_width: int >= 1")

.Attr("top_paths: int >= 1")

.Attr("merge_repeated: bool = true")

//新添加了两个参数

.Attr("label_selection_size: int >= 0 = 0")

.Attr("label_selection_margin: float")

.Output("decoded_indices: top_paths * int64")

.Output("decoded_values: top_paths * int64")

.Output("decoded_shape: top_paths * int64")

.Output("log_probability: float")

.SetShapeFn([](InferenceContext* c) {

ShapeHandle inputs;

ShapeHandle sequence_length;

TF_RETURN_IF_ERROR(c->W

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值