ERROR: Creating graph in session failed...Not found: Op type not registered 'GatherTree' in binary

Problem Description

A model is trained with TensorFlow in Python and then loaded in a C++ project. During graph creation (tensorflow::Status status_create = session->Create(graphdef);) the following error is reported:

ERROR: Creating graph in session failed...Not found: Op type not registered 'GatherTree' in binary running on aisp. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) tf.contrib.resampler should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
load model failed
Invalid argument: Session was not created with a graph before Run()!
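For context, this is the standard C++ loading path where the error surfaces. A minimal sketch, assuming a frozen GraphDef on disk (the model path and error handling are placeholders, not from the original project):

#include <memory>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"

int main() {
  tensorflow::GraphDef graphdef;
  // Read the frozen graph exported from Python (path is hypothetical).
  tensorflow::Status status_load = tensorflow::ReadBinaryProto(
      tensorflow::Env::Default(), "model/frozen_graph.pb", &graphdef);
  if (!status_load.ok()) return 1;

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  // Fails with "Op type not registered 'GatherTree'" when the op is not
  // compiled into this binary; a later Run() then also fails because no
  // graph was ever created in the session.
  tensorflow::Status status_create = session->Create(graphdef);
  if (!status_create.ok()) return 1;
  return 0;
}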

Solution

1. In the TensorFlow source tree, locate /mymodules/tensorflow/tensorflow/contrib/seq2seq/kernels and add the beam_search_ops.cc and beam_search_ops.h in that folder to your own C++ project.
2. Go to /mymodules/tensorflow/tensorflow/contrib/seq2seq/ops, open the beam_search_ops.cc there, and copy

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

into the beam_search_ops.h from step 1, and copy the op registration code from this beam_search_ops.cc into the beam_search_ops.cc from step 1.
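This works because REGISTER_OP and REGISTER_KERNEL_BUILDER expand to static initializers: once the merged files are compiled and linked into your binary, GatherTree is registered before session->Create(graphdef) looks it up. To confirm the registration took effect, here is a minimal check against the op registry (this helper is illustrative, not part of the fix):

#include "tensorflow/core/framework/op.h"

bool GatherTreeRegistered() {
  const tensorflow::OpRegistrationData* op_reg_data = nullptr;
  // LookUp succeeds once the static REGISTER_OP("GatherTree") initializer
  // has run, i.e. once beam_search_ops.cc is linked into this binary.
  return tensorflow::OpRegistry::Global()
      ->LookUp("GatherTree", &op_reg_data)
      .ok();
}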

Final Result

Add the following two files to your project:
beam_search_ops.h

#ifndef TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
#define TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
class OpKernelContext;

namespace functor {

template <typename Device, typename T>
struct GatherTree {
  void operator()(OpKernelContext* ctx, const Device& d,
                  typename TTypes<T, 3>::ConstTensor step_ids,
                  typename TTypes<T, 3>::ConstTensor parent_ids,
                  TTypes<int32>::ConstVec max_sequence_lengths,
                  const T end_token, typename TTypes<T, 3>::Tensor beams);
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_

beam_search_ops.cc

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "beam_search_ops.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
    
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("GatherTree")
    .Input("step_ids: T")
    .Input("parent_ids: T")
    .Input("max_sequence_lengths: int32")
    .Input("end_token: T")
    .Output("beams: T")
    .Attr("T: {int32}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle step_ids, parent_ids, max_sequence_lengths, end_token;

      // step_ids, parent_ids, and output are all shaped:
      //   [max_time, batch_size, beam_width].
      // max_sequence_length is shaped [batch_size] and end_token is a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max_sequence_lengths));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &end_token));
      TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
      DimensionHandle batch_size = c->Dim(step_ids, 1);
      TF_RETURN_IF_ERROR(
          c->Merge(batch_size, c->Dim(max_sequence_lengths, 0), &batch_size));
      ShapeHandle step_ids_prefix = c->Matrix(c->Dim(step_ids, 0), batch_size);
      TF_RETURN_IF_ERROR(c->MergePrefix(step_ids, step_ids_prefix, &step_ids,
                                        &step_ids_prefix));

      c->set_output(0, step_ids);
      return tensorflow::Status::OK();
    })
    .Doc(R"doc(
Calculates the full beams from the per-step ids and parent beam ids.

On CPU, if an out of bound parent id is found, an error is returned.
On GPU, if an out of bound parent id is found, a -1 is stored in the
corresponding output value and the execution for that beam returns early.

For a given beam, past the time step containing the first decoded `end_token`
all values are filled in with `end_token`.

TODO(ebrevdo): fill in the remainder of this docstring.

step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
max_sequence_lengths: `[batch_size]`.
end_token: `[]`.
beams: `[max_time, batch_size, beam_width]`.
)doc");

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class GatherTreeOp : public OpKernel {
 public:
  explicit GatherTreeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Device& device = ctx->eigen_device<Device>();
    const Tensor& step_ids = ctx->input(0);
    const Tensor& parent_ids = ctx->input(1);
    const Tensor& max_sequence_lengths = ctx->input(2);
    const Tensor& end_token = ctx->input(3);
    const TensorShape& step_ids_shape = step_ids.shape();
    OP_REQUIRES(
        ctx, step_ids_shape.dims() == 3,
        errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
                                step_ids_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()),
                errors::InvalidArgument(
                    "max_sequence_lengths must be a vector, saw shape: ",
                    max_sequence_lengths.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(end_token.shape()),
        errors::InvalidArgument("end_token must be a scalar, saw shape: ",
                                end_token.shape().DebugString()));
    OP_REQUIRES(
        ctx, step_ids_shape == parent_ids.shape(),
        errors::InvalidArgument(
            "step_ids.shape must match parent_ids.shape.  but shapes are: ",
            step_ids_shape.DebugString(), " and ",
            parent_ids.shape().DebugString()));
    OP_REQUIRES(
        ctx,
        step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0),
        errors::InvalidArgument("batch size dimensions step_ids.shape[1] and "
                                "max_sequence_lengths.shape[0] must match.  "
                                "but shapes are: ",
                                step_ids_shape.DebugString(), " and ",
                                max_sequence_lengths.shape().DebugString()));
    Tensor* beams;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
    typename TTypes<T, 3>::ConstTensor step_ids_t(step_ids.tensor<T, 3>());
    typename TTypes<T, 3>::ConstTensor parent_ids_t(parent_ids.tensor<T, 3>());
    typename TTypes<int32>::ConstVec max_seq_lens_t =
        max_sequence_lengths.vec<int32>();
    typename TTypes<T>::ConstScalar end_token_t(end_token.scalar<T>());
    typename TTypes<T, 3>::Tensor beams_t(beams->tensor<T, 3>());
    const T end_token_value = end_token_t();
    functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
                                     max_seq_lens_t, end_token_value, beams_t);
  }
};

#define REGISTER_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("GatherTree").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      GatherTreeOp<CPUDevice, T>);
REGISTER_KERNEL(int32);
#undef REGISTER_KERNEL

namespace functor {

// CPU specialization
template <>
struct GatherTree<CPUDevice, int32> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  TTypes<int32, 3>::ConstTensor step_ids,
                  TTypes<int32, 3>::ConstTensor parent_ids,
                  TTypes<int32>::ConstVec max_sequence_lengths,
                  const int32 end_token, TTypes<int32, 3>::Tensor beams) {
    const int32 max_time = parent_ids.dimension(0);
    const int32 batch_size = parent_ids.dimension(1);
    const int32 beam_width = parent_ids.dimension(2);
    beams.setConstant(end_token);

    auto DoWork = [&, ctx, end_token](int start_batch_beam,
                                      int limit_batch_beam) {
      for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
        const int32 batch = i / beam_width;
        const int32 beam = i % beam_width;
        const int32 max_seq_len_b =
            Eigen::numext::mini(max_time, max_sequence_lengths(batch));
        if (max_seq_len_b <= 0) {
          continue;
        }
        beams(max_seq_len_b - 1, batch, beam) =
            step_ids(max_seq_len_b - 1, batch, beam);
        int32 parent = parent_ids(max_seq_len_b - 1, batch, beam);
        for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
          if (parent < 0 || parent >= beam_width) {
            ctx->SetStatus(
                errors::InvalidArgument("Saw invalid parent id ", parent,
                                        " at (batch, time, beam) == (", batch,
                                        ", ", level, ", ", beam, ")"));
            return;
          }
          beams(level, batch, beam) = step_ids(level, batch, parent);
          parent = parent_ids(level, batch, parent);
        }
        // Not necessary when using a BeamSearchDecoder, but necessary
        // when a user feeds in possibly broken trajectory (i.e., non-eos
        // entries in a beam following eos entries).
        bool finished = false;
        for (int32 time = 0; time < max_seq_len_b; ++time) {
          if (finished) {
            beams(time, batch, beam) = end_token;
          } else if (beams(time, batch, beam) == end_token) {
            finished = true;
          }
        }
      }
    };
    // Guesstimate of cost; ~5 lookup/store/compare per inner beam
    // traversal time step.
    const int64 batch_beam_cost =
        Eigen::TensorOpCost::DivCost<int32>() +
        6 * Eigen::TensorOpCost::AddCost<int32>() +
        2 * max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          batch_size * beam_width, batch_beam_cost, DoWork);
  }
};

}  // namespace functor

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                            \
  template <>                                                          \
  void GatherTree<GPUDevice, T>::operator()(                           \
      OpKernelContext* ctx, const GPUDevice& d,                        \
      typename TTypes<T, 3>::ConstTensor step_ids,                     \
      typename TTypes<T, 3>::ConstTensor parent_ids,                   \
      TTypes<int32>::ConstVec max_sequence_lengths, const T end_token, \
      typename TTypes<T, 3>::Tensor beams);                            \
  extern template struct GatherTree<GPUDevice, T>;

DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T)                          \
  REGISTER_KERNEL_BUILDER(Name("GatherTree")            \
                              .Device(DEVICE_GPU)       \
                              .TypeConstraint<T>("T")   \
                              .HostMemory("end_token"), \
                          GatherTreeOp<GPUDevice, T>);

REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow
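One caveat worth knowing: if beam_search_ops.cc ends up in a static library rather than being compiled directly into the executable, the linker may discard the object file because nothing references it, so the static REGISTER_OP / REGISTER_KERNEL_BUILDER initializers never run and the GatherTree error reappears at runtime. Compiling the file directly into the target, or forcing the archive to be kept (for example with GNU ld's --whole-archive), avoids this.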