TensorFlow Kernels

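A node in a TensorFlow graph is called an OP (operator). OPs are written in C++ and can be extended freely by users. Extending an OP takes two steps. First, declare the OP, that is, register it with REGISTER_OP. Second, implement it; the implementation is called an op kernel. Kernels must also be registered, using REGISTER_KERNEL_BUILDER, and the implementing class must inherit from OpKernel.

Building the graph only needs the OP declaration. The kernel is looked up and instantiated only at run time. One OP can have different implementations on different devices. The example below is the simplest one from the official site: the declaration of the ZeroOut OP and the implementation of its kernel. In practice, the declaration and the implementation can live in separate files. OP registration is analyzed in detail in the post "tensorflow之op" on wyg_031113's CSDN blog; this post focuses on the kernel.

The kernel is what actually performs the computation.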

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
// OP declaration
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

// OP implementation
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};
// Kernel registration
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
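
The split between declaring an OP and registering one kernel per device can be pictured with a small stand-alone model. This is only a sketch under stated assumptions: ToyRegistry, ToyKernel, CpuZeroOut, GpuZeroOut and MakeDemoRegistry are hypothetical names invented here, and the real TensorFlow registries are considerably more involved.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>

// Toy stand-ins; none of these are TensorFlow types.
struct ToyKernel {
  virtual ~ToyKernel() = default;
  virtual std::string Compute() = 0;
};

using ToyKernelFactory = std::function<std::unique_ptr<ToyKernel>()>;

class ToyRegistry {
 public:
  // Step 1: declare the op. Graph construction needs nothing more.
  void RegisterOp(const std::string& op) { ops_.insert(op); }

  // Step 2: register one implementation per (op, device) pair.
  void RegisterKernel(const std::string& op, const std::string& device,
                      ToyKernelFactory factory) {
    assert(ops_.count(op) == 1 && "kernel registered for an undeclared op");
    kernels_[{op, device}] = std::move(factory);
  }

  bool HasOp(const std::string& op) const { return ops_.count(op) == 1; }

  // Run time: look up and instantiate the kernel for the chosen device.
  // Returns nullptr when the op has no kernel for that device.
  std::unique_ptr<ToyKernel> CreateKernel(const std::string& op,
                                          const std::string& device) const {
    auto it = kernels_.find({op, device});
    if (it == kernels_.end()) return nullptr;
    return it->second();
  }

 private:
  std::set<std::string> ops_;
  std::map<std::pair<std::string, std::string>, ToyKernelFactory> kernels_;
};

struct CpuZeroOut : ToyKernel {
  std::string Compute() override { return "ZeroOut on CPU"; }
};
struct GpuZeroOut : ToyKernel {
  std::string Compute() override { return "ZeroOut on GPU"; }
};

// Builds a registry with one op and two device-specific kernels.
ToyRegistry MakeDemoRegistry() {
  ToyRegistry r;
  r.RegisterOp("ZeroOut");
  r.RegisterKernel("ZeroOut", "CPU",
                   [] { return std::make_unique<CpuZeroOut>(); });
  r.RegisterKernel("ZeroOut", "GPU",
                   [] { return std::make_unique<GpuZeroOut>(); });
  return r;
}
```

In this model, asking for a kernel on a device that has no registration simply yields nullptr, which mirrors the idea that an OP may be declared without being implemented everywhere.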

The Kernel interface

tensorflow/core/framework/op_kernel.h

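Synchronous computation: the Compute method

  1. A kernel's computation may be synchronous or asynchronous. Compute must be thread-safe. Most kernels are synchronous.
  2. A synchronous kernel must never block the current thread on a lock, condition variable, or similar primitive while expecting another kernel to release it. The executor may have only a fixed number of threads, so if all of them block, the result is a deadlock.
  3. A kernel that really does need to block, such as RecvOp or DequeueOp, must inherit from AsyncOpKernel, a subclass of OpKernel.
  4. In most cases an AsyncOpKernel should use the cancellation mechanism: context->cancellation_manager().
  5. An op's inputs and outputs are all obtained through the OpKernelContext* context parameter. The resulting status is reported the same way, via context->SetStatus().
  6. In a synchronous computation, context is guaranteed to stay alive until the function returns.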
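
Because Compute returns void, the only way a kernel can signal failure is through the context. The sketch below models that convention without TensorFlow: ToyContext, TOY_REQUIRES and ToyZeroOutCompute are hypothetical names, the macro only imitates the spirit of the OP_REQUIRES family, and the non-empty-input check is added purely to have something to fail on (the ZeroOut example above accepts empty input).

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for OpKernelContext: inputs, one output, and a status slot.
struct ToyContext {
  std::vector<int> input;
  std::vector<int> output;
  bool ok = true;
  std::string error;

  void SetStatus(const std::string& message) {
    ok = false;
    error = message;
  }
};

// Hypothetical macro: record the error on the context and leave Compute early.
#define TOY_REQUIRES(ctx, cond, msg) \
  do {                               \
    if (!(cond)) {                   \
      (ctx)->SetStatus(msg);         \
      return;                        \
    }                                \
  } while (0)

// Compute returns void; success or failure is visible only via the context.
void ToyZeroOutCompute(ToyContext* ctx) {
  TOY_REQUIRES(ctx, !ctx->input.empty(), "input must not be empty");
  ctx->output.assign(ctx->input.size(), 0);  // zero every element
  ctx->output[0] = ctx->input[0];            // keep the first one
}
```

A caller inspects ctx->ok after the call instead of checking a return value, which is the same shape as reading the status from the context after Compute.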
Construction and destruction
class OpKernel {
 public:
  // The kernel is not instantiated by the scheduler, so a subclass may do expensive initialization in its constructor.
  explicit OpKernel(OpKernelConstruction* context);

  // Allows a kernel to mark itself as a "deferred" op. If true, the executor provides
  // access to the `OpKernelContext::inc_num_deferred_ops_function()` and
  // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
  OpKernel(OpKernelConstruction* context, bool is_deferred);
  // Allows a subclass to supply a custom NodeDef.
  OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
           bool is_deferred);

  virtual ~OpKernel();
     
  // Core computation function; subclasses override it to implement their own behavior.
  virtual void Compute(OpKernelContext* context) = 0;

  // Returns nullptr iff this op kernel is synchronous.
  virtual AsyncOpKernel* AsAsync() { return nullptr; }

  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution for example
  // to "inline" inexpensive kernels.
  virtual bool IsExpensive() { return expensive_; }

  // Returns a pointer to the tensor stored inside constant ops.
  virtual const Tensor* const_tensor() const { return nullptr; }

  // Accessors: return the node definition, the node name, and so on.
  const NodeDef& def() const { return props_->node_def; }
  const std::string& name() const { return props_->node_def.name(); }
  absl::string_view name_view() const { return name_view_; }
  const std::string& type_string() const { return props_->node_def.op(); }
  absl::string_view type_string_view() const { return type_string_view_; }
  const std::string& requested_input(int i) const {
    return props_->node_def.input(i);
  }
  const std::string& requested_device() const {
    return props_->node_def.device();
  }

  int num_inputs() const { return props_->input_types.size(); }
  DataType input_type(int i) const { return props_->input_types[i]; }
  const DataTypeVector& input_types() const { return props_->input_types; }
  const MemoryTypeVector& input_memory_types() const {
    return input_memory_types_;
  }

  int num_outputs() const { return props_->output_types.size(); }
  DataType output_type(int o) const { return props_->output_types[o]; }
  const DataTypeVector& output_types() const { return props_->output_types; }
  const MemoryTypeVector& output_memory_types() const {
    return output_memory_types_;
  }

  Status InputRange(StringPiece input_name, int* start, int* stop) const;
  Status OutputRange(StringPiece output_name, int* start, int* stop) const;

  // Returns `true` if and only if this kernel uses deferred execution.
  bool is_deferred() const { return is_deferred_; }

  // Returns a trace string for current computation, op name/type and input
  // tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
  // should use the default implementation.
  virtual std::string TraceString(const OpKernelContext& ctx,
                                  bool verbose) const;

 protected:
  std::string ShapeTraceString(const OpKernelContext& ctx) const;

 private:
  const std::shared_ptr<const NodeProperties> props_;
  const MemoryTypeVector input_memory_types_;
  const MemoryTypeVector output_memory_types_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;
  const absl::string_view name_view_;
  const absl::string_view type_string_view_;
  const int graph_def_version_;
  const bool is_deferred_;
  bool expensive_;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};

Asynchronous computation: AsyncOpKernel

class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

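  // The asynchronous computation must call this callback when it finishes, to notify the scheduler.
  // It may be called only once; after the call, context and this may already have been destroyed.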
  typedef std::function<void()> DoneCallback;
  // Override this method to implement an asynchronous computation.
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;

  AsyncOpKernel* AsAsync() override { return this; }

  void Compute(OpKernelContext* context) override;
};
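
To see why a dequeue-style op fits the ComputeAsync shape, consider a single-threaded model in which an empty queue parks the request instead of blocking a thread. This is a sketch only: ToyQueue, ToyContext and the method names are hypothetical, and TensorFlow's real queue ops also handle cancellation, capacity, and locking.

```cpp
#include <cassert>
#include <deque>
#include <functional>
#include <utility>

struct ToyContext {
  int output = 0;
};

using DoneCallback = std::function<void()>;

// Toy queue whose dequeue never blocks: when empty, it stores the pending
// request and completes it from a later Enqueue call.
class ToyQueue {
 public:
  void DequeueAsync(ToyContext* ctx, DoneCallback done) {
    if (!items_.empty()) {
      ctx->output = items_.front();
      items_.pop_front();
      done();  // completed immediately
      return;
    }
    // Nothing available: remember the request and return without blocking.
    waiters_.push_back(Waiter{ctx, std::move(done)});
  }

  void Enqueue(int value) {
    if (waiters_.empty()) {
      items_.push_back(value);
      return;
    }
    Waiter w = std::move(waiters_.front());
    waiters_.pop_front();
    w.ctx->output = value;
    w.done();  // called exactly once; ctx must not be touched afterwards
  }

 private:
  struct Waiter {
    ToyContext* ctx;
    DoneCallback done;
  };
  std::deque<int> items_;
  std::deque<Waiter> waiters_;
};
```

The thread that called DequeueAsync is free to run other kernels while the request waits, which is the property a fixed-size executor thread pool depends on.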

OpKernelConstruction: what a kernel receives at construction

The constructor is passed:

  1. Device: device
  2. Allocator: Allocator
  3. Resource manager: ResourceMgr
  4. Node
  5. Env
  6. Function library: FunctionLib

class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
                       ResourceMgr* resource_mgr,
                       const std::shared_ptr<const NodeProperties>& props,
                       const MemoryTypeSlice& input_memory_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  Env* env() const { return device_->env(); }

  
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return props_->node_def; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return props_->input_types.size(); }
  DataType input_type(int i) const { return props_->input_types[i]; }
  const DataTypeSlice& input_types() const { return props_->input_types_slice; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return props_->output_types.size(); }
  DataType output_type(int i) const { return props_->output_types[i]; }
  const DataTypeSlice& output_types() const {
    return props_->output_types_slice;
  }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value.  If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return resource_mgr_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device.  See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  FunctionLibraryRuntime* flib_;
  ResourceMgr* const resource_mgr_;
  std::shared_ptr<const NodeProperties> props_;
  MemoryTypeSlice input_memory_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow access from OpKernel ctor.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};
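
A typical use of the construction object is to read attrs once in the kernel constructor, validate them, and keep them in members for later Compute calls. The following stand-alone sketch imitates that flow. ToyConstruction and ToyZeroOutKernel are hypothetical, the attr name preserve_index is used only for illustration, and the real GetAttr is a template that returns a Status rather than a bool.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for OpKernelConstruction: node attrs plus a status slot.
class ToyConstruction {
 public:
  explicit ToyConstruction(std::map<std::string, int> attrs)
      : attrs_(std::move(attrs)) {}

  // Returns false when the attr is not present.
  bool GetAttr(const std::string& name, int* value) const {
    auto it = attrs_.find(name);
    if (it == attrs_.end()) return false;
    *value = it->second;
    return true;
  }

  // Records the first configuration error.
  void SetStatus(const std::string& error) {
    if (error_.empty()) error_ = error;
  }
  bool ok() const { return error_.empty(); }

 private:
  std::map<std::string, int> attrs_;
  std::string error_;
};

class ToyZeroOutKernel {
 public:
  // Runs once per kernel: read and validate attrs here, not in Compute.
  explicit ToyZeroOutKernel(ToyConstruction* c) {
    if (!c->GetAttr("preserve_index", &preserve_index_)) {
      c->SetStatus("attr preserve_index not found");
      return;
    }
    if (preserve_index_ < 0) c->SetStatus("preserve_index must be >= 0");
  }

  // Runs many times: zero everything except the preserved position.
  // Only meant to be called when construction reported ok().
  std::vector<int> Compute(const std::vector<int>& input) const {
    assert(preserve_index_ >= 0);
    std::vector<int> output(input.size(), 0);
    const std::size_t keep = static_cast<std::size_t>(preserve_index_);
    if (keep < input.size()) output[keep] = input[keep];
    return output;
  }

 private:
  int preserve_index_ = 0;
};
```

Validating in the constructor means a misconfigured node is rejected before any step runs, so Compute can stay free of attr lookups.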

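Helper classes for OP inputs and outputs

Some inputs are lists: a single name stands for several inputs of the same type. You can think of it as Tensor tensors[N]. The same applies to outputs.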

  1. OpInputList
  2. OpMutableInputList
  3. OpOutputList
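
One way to picture a list-valued input is as a named [start, stop) range over the kernel's flat input array, which is the shape suggested by OpKernel::InputRange and the NameRangeMap members in the class above. The sketch below is a toy model with hypothetical names (ToyInputs, Add, InputList) and plain ints in place of Tensors.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a kernel's flat input array plus its name -> range map.
class ToyInputs {
 public:
  // Appends `values` as the inputs for `name`; a single input is a list of 1.
  void Add(const std::string& name, const std::vector<int>& values) {
    const int start = static_cast<int>(flat_.size());
    flat_.insert(flat_.end(), values.begin(), values.end());
    ranges_[name] = {start, static_cast<int>(flat_.size())};
  }

  // Reports the half-open range [start, stop) of flat inputs for `name`.
  bool InputRange(const std::string& name, int* start, int* stop) const {
    auto it = ranges_.find(name);
    if (it == ranges_.end()) return false;
    *start = it->second.first;
    *stop = it->second.second;
    return true;
  }

  // Toy counterpart of input_list: copies out the slice for `name`.
  // Returns an empty vector for an unknown name.
  std::vector<int> InputList(const std::string& name) const {
    int start = 0;
    int stop = 0;
    if (!InputRange(name, &start, &stop)) return {};
    return std::vector<int>(flat_.begin() + start, flat_.begin() + stop);
  }

 private:
  std::vector<int> flat_;
  std::map<std::string, std::pair<int, int>> ranges_;
};
```

With this layout a non-list input is just a range of length one, which matches the header comment that a non-list name yields a one-element list.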

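OpKernelContext: the argument to Compute

This class is huge and rich in content. The context provides everything an op needs during Compute. Logically, its contents fall into the following groups.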

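Accessing inputs and outputs

Input and Output. Attrs, by contrast, are fixed at graph construction time and can already be read from OpKernelConstruction.

Outputs also involve allocating memory for Tensors.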
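
Output allocation does not always mean a fresh buffer: the forward_input family of methods on OpKernelContext lets a kernel reuse an input buffer as its output when that is safe. The sketch below captures only the core idea (reuse when nobody else holds the buffer and the element count matches, otherwise allocate). ToyBuffer and ForwardOrAllocate are hypothetical, std::shared_ptr::use_count stands in for the tensor buffer refcount, and the real checks on dtype, memory type, and reservations are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

using ToyBuffer = std::shared_ptr<std::vector<float>>;

// Tries to hand the input buffer over as the output. `input` is the slot that
// owns the input; on success it is left empty and *forwarded is set to true.
// Otherwise a new zero-filled buffer is allocated and *forwarded is false.
ToyBuffer ForwardOrAllocate(ToyBuffer* input, std::size_t num_elements,
                            bool* forwarded) {
  assert(input != nullptr && forwarded != nullptr);
  const bool sole_owner = *input != nullptr && input->use_count() == 1;
  if (sole_owner && (*input)->size() == num_elements) {
    *forwarded = true;
    return std::move(*input);  // in-place: output aliases the old input
  }
  *forwarded = false;
  return std::make_shared<std::vector<float>>(num_elements, 0.0f);
}
```

When another reader still holds the input, writing into it in place would corrupt what that reader sees, so the function falls back to a separate allocation.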

Execution environment

env, device, resource_mgr, node, graph, session, step_id, function_library, allocator


class OpKernelContext {
 public:
  // The first element of a WrappedAllocator is a "base" Allocator and
  // the second element is that Allocator wrapped by a
  // TrackingAllocator
  typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;

  // TODO(zhifengc): Do some cleanup of Params.
  // The Params struct is passed in to initialize an OpKernelContext,
  // and must outlive the OpKernelContext.
  struct Params {
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64_t step_id = 0;

    // Timestamp for the start of graph execution. Used for latency metrics.
    int64_t start_time_usecs = 0;

    // The deadline for the session to complete by. Empty if unspecified.
    absl::optional<absl::Time> deadline;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    bool track_allocations = false;
    bool log_memory = false;

    // Array indexed by output number for this node
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container..
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    RendezvousInterface* rendezvous = nullptr;

    // Mechanism for executing a collective op that needs to coordinate
    // with parallel instances running on other devices.
    CollectiveExecutor* collective_executor = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // Unique session identifier. Can be empty.
    std::string session_handle;

    // Metadata about the session. Can be nullptr.
    const SessionMetadata* session_metadata = nullptr;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    bool is_input_dead = false;

    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
        nullptr;

    // Device context.
    DeviceContext* op_device_context = nullptr;

    // Control-flow op supports.
    FrameAndIter frame_iter;

    // Function call supports.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    GraphCollector* graph_collector = nullptr;
    bool run_all_kernels_inline = false;
    const std::string* executor_type = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;

    // Support for forwarding reservations (used by ScopedAllocator).
    static constexpr int kNeverForward = -2;
    static constexpr int kNoReservation = -1;
    // Values in [0,...) represent reservations for the indexed output.
    const int* forward_from_array = nullptr;

    // For tracking actively running deferred ops.
    std::function<void()> inc_num_deferred_ops_function;
    std::function<void()> dec_num_deferred_ops_function;

    absl::optional<ManagedStackTrace> stack_trace = {};

    // For implementing `OpKernelContext::output_required()`. If null, all
    // outputs are required.
    bool* outputs_required_array = nullptr;

    // For access to distributed coordination service.
    CoordinationServiceAgent* coordination_service_agent = nullptr;
  };

  // params must outlive the OpKernelContext.
  explicit OpKernelContext(Params* params);
  OpKernelContext(Params* params, int num_outputs);
  ~OpKernelContext();

  Env* env() const { return params_->device->env(); }

  int64_t step_id() const { return params_->step_id; }

  int64_t start_time_usecs() const { return params_->start_time_usecs; }

  // The deadline for the session to complete by. Empty if unspecified in
  // RunOptions.
  absl::optional<absl::Time> deadline() const { return params_->deadline; }

  const OpKernel& op_kernel() const { return *params_->op_kernel; }

  // Stack trace of where the op was defined (if defined in eager mode).
  const absl::optional<ManagedStackTrace>& stack_trace() const {
    return params_->stack_trace;
  }

  // Input/output signature.

  int num_inputs() const { return params_->inputs->size(); }
  DataType input_dtype(int index) const;
  Status input_dtype(StringPiece name, DataType* dtype) const;
  MemoryType input_memory_type(int index) const;

  int num_outputs() const { return outputs_.size(); }
  DataType expected_output_dtype(int index) const;
  MemoryType output_memory_type(int index) const;

  // Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index) const;

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef.  If the named output is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  Status input_list(StringPiece name, OpInputList* list);

  // For mutable inputs, use the following together to make sure there
  // is no concurrent access to mutable_input(), e.g.:
  // {
  //   Tensor& t = context->mutable_input(index);
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   // modify the values in t
  // }
  // REQUIRES: IsRefType(input_dtype(index))
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs.  REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef.  If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replace the corresponding named Ref Input to use the storage
  // buffer used by tensor. If !lock_held the input mutex will be
  // acquired before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);

  // If non-null, kernels should populate with any partition subgraphs created.
  GraphCollector* graph_collector() { return params_->graph_collector; }

  // If True, hint that all kernels in functions called by this kernel, should
  // be treated as "inexpensive", and hence executed on the scheduling thread.
  bool run_all_kernels_inline() const {
    return params_->run_all_kernels_inline;
  }

  // Returns the registered name for the executor type that is executing the
  // current kernel. If empty, the default executor is used.
  const std::string& executor_type() const;

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true when an alias to input[input_index], reshaped to output_shape,
  // which is safe to use for in-place computation was written to *output.
  // Returns false if input[input_index] has a refcount greater than one, or if
  // its type does not match the expected output type of output[output_index],
  // or the number of elements in input[input_index] does not equal the number
  // of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in dtype, shape,
  //     memory_type, and attr,
  //   * refcount on the underlying buffer is one.
  //   * Either there is no forwarding reservation for either input_index
  //     or output_index or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //     OR (input_index is not in forward_from_array AND
  //         (output_index == kNoReservation OR
  //          forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer. The index of the
  // forwarded input will be assign to output argument forwarded_input (if it's
  // not nullptr). If no inputs are forwarded, forwarded_input will be assigned
  // -1.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output,
      int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;

  // Tries to reuse one of the inputs given in input_indices as a temporary.
  // If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;

  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
                                          AllocatorAttributes(), out_temp);
  }

  // Output

  // Returns the named list-valued output in "list", as defined in the OpDef.
  // If the named output is not list-valued, returns a one-element list.
  Status output_list(StringPiece name, OpOutputList* list);

  // If output_required(index) returns true, the OpKernel's Compute() method
  // should call allocate_output(index, ...), set_output(index, ...),
  // set_output_ref(index, ...), or set the status to a non-ok value.
  // If it returns false, it may output, but is not required to do so.
  bool output_required(int index) const {
    return !params_->outputs_required_array ||
           params_->outputs_required_array[index];
  }

  // If output_expects_forwarding returns true, the OpKernel's Compute() method
  // should not allocate the output with allocate_output but instead needs to
  // use forward_input.
  bool output_expects_forwarding(int index) const {
    return params_->forward_from_array != nullptr &&
           params_->forward_from_array[index] >= 0;
  }

  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are two methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_output. This should be used to allocate any tensor
  // that is going to be used as an output from the Op at the end of
  // the current execution. The caller indicates which output the
  // Tensor will be assigned to, and the call returns the
  // newly-allocated Tensor. The Tensor can subsequently be assigned
  // to during kernel execution, and will be used as the designated
  // output when the kernel execution completes.
  //
  // 2) allocate_temp. This should be used to allocate any scratch
  // storage that is needed while the kernel is executing, and will
  // not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a Tensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;

  // Allocates a temporary Tensor of the specified type and
  // shape. Devices such as GPUs that enqueue Ops for lazy execution
  // may retain references to the temporary tensors after the Op's
  // Compute method has run. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr) {
    return allocate_temp(type, shape, out_temp, allocator_attr,
                         AllocationAttributes());
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp) {
    return allocate_temp(type, shape, out_temp, AllocatorAttributes());
  }

  // Copies a tensor (allocated by the caller) to the specified output
  // index.  REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  Status set_output(StringPiece name, const Tensor& tensor);
  Status set_output(StringPiece name, Tensor&& tensor);
  void set_output(int index, const Tensor& tensor);
  void set_output(int index, Tensor&& tensor);

  // To output a reference.  Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Returns nullptr if allocate_output() or set_output() have not been called.
  Status mutable_output(StringPiece name, Tensor** tensor);

  // Return the DeviceContext that should be used for this Op.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Returns nullptr if the device did not provide one.
  template <typename T>
  T* op_device_context();
  DeviceContext* op_device_context() {
    DeviceContext* ret = params_->op_device_context;
    if (ret == nullptr) {
      auto* dev_info = device()->tensorflow_accelerator_device_info();
      if (dev_info) ret = dev_info->default_context;
    }
    return ret;
  }

  AllocatorAttributes input_alloc_attr(int index) const {
    if (params_->input_alloc_attrs == nullptr) {
      return AllocatorAttributes();
    } else {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, params_->input_alloc_attrs->size());
      return (*params_->input_alloc_attrs)[index];
    }
  }

  AllocatorAttributes output_alloc_attr(int index) const {
    return params_->output_attr_array[index];
  }

  gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
    gtl::InlinedVector<WrappedAllocator, 4> retrieved;
    if (tracking_state_) {
      mutex_lock lock(tracking_state_->mu);
      retrieved.swap(tracking_state_->wrapped_allocators);
    }
    return retrieved;
  }

  // Communication.
  //
  // An op kernel communicates with outside environment through
  // Rendezvous Send() and Recv().
  RendezvousInterface* rendezvous() const { return params_->rendezvous; }

  CollectiveExecutor* collective_executor() const {
    return params_->collective_executor;
  }

  // An op kernel can access the session state it belongs to.
  SessionState* session_state() const { return params_->session_state; }

  // Unique identifier of the session it belongs to. Can be empty.
  std::string session_handle() const { return params_->session_handle; }

  // Metadata about the session. Can be nullptr.
  const SessionMetadata* session_metadata() const {
    return params_->session_metadata;
  }

  // An op kernel can access the tensor store of the run it belongs to.
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel invoke functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollectorInterface* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }

  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;

  // Error handling.

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Per-step container for use by white-listed internal ops.
  ScopedStepContainer* step_container() const {
    return params_->step_container;
  }

  // Access to distributed coordination service.
  CoordinationServiceAgent* coordination_service_agent() const {
    return params_->coordination_service_agent;
  }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.
  //
  // The following functions all have versions that return Status
  // to capture error conditions, and are strongly preferred.
  Tensor* mutable_output(int index);
  mutex* input_ref_mutex(int index);
  void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
  TensorValue release_output(int index);

  bool track_allocations() const { return params_->track_allocations; }

  // Records temp memory allocation. Tensor object is recorded to identify the
  // case where temp memory is used as output memory.
  void record_temp_memory_allocation(int64_t size, const Tensor& t)
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded size of temporary memory;
  int64_t temp_memory_allocated() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Records persistent memory allocation, size can be negative indicating
  // deallocation.
  void record_persistent_memory_allocation(int64_t size, int64_t alloc_id = -1)
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded size and ids of persistent memory.
  int64_t persistent_memory_allocated() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  std::vector<int64_t> persistent_alloc_ids() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Resets counters for temp and persistent memory and recorded ids.
  void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  bool input_is_ref(int index) const;

  void set_record_memory_consumption(bool v);

  // Used by OpKernel implementations to track actively running deferred ops.
  //
  // A deferred op is one whose Compute method returns (or whose ComputeAsync
  // method invokes the callback) when work is scheduled onto a device. At that
  // point, we don't know when the work will actually complete (or if it has
  // already completed) on the device. These functions allow the executor to
  // track the status of deferred ops and act accordingly.
  //
  // Deferred OpKernel implementations must use these methods to get two
  // functions. It then must call these two functions in pairs, before and after
  // device execution, respectively.
  TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
    DCHECK(params_->op_kernel->is_deferred());
    return params_->inc_num_deferred_ops_function
               ? params_->inc_num_deferred_ops_function
               : []() {};
  }
  TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
    DCHECK(params_->op_kernel->is_deferred());
    return params_->dec_num_deferred_ops_function
               ? params_->dec_num_deferred_ops_function
               : []() {};
  }

  Allocator* get_allocator(AllocatorAttributes attr);

 private:
  bool record_memory_consumption_ = false;

  // Internal common method used when allocating tensor memory
  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor,
                         AllocatorAttributes allocator_attr) {
    return allocate_tensor(type, shape, out_tensor, allocator_attr,
                           AllocationAttributes());
  }

  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor, AllocatorAttributes allocator_attr,
                         const AllocationAttributes& allocation_attr);

  // Helpers for `set_output()`.

  // Returns `true` if the tensor was copied into an allocated output.
  bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);

  void maybe_track_allocations_for_set_output(const Tensor& tensor);

  Status get_input_index(StringPiece name, int* out_index) const;
  Status get_output_index(StringPiece name, int* out_index) const;

  // Initialize the allocated_scope_ids_ set the first time this method is
  // called.
  void maybe_initialize_scope_id_set();

  Status status_;
  friend class CollectiveExecutor;  // for access to params_
  Params* params_;                  // not owned
  gtl::InlinedVector<TensorValue, 4> outputs_;

  // Keep track of calls to ScopedAllocator.
  // TODO(ayushd): change to absl::flat_hash_set.
  std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;

  // The following data members are only used when allocation tracking is
  // enabled, memory consumption is being recorded, or tensor access is being
  // recorded.
  struct TrackingState {
    mutable mutex mu;
    gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
        TF_GUARDED_BY(mu);

    mutable mutex stats_mu;
    int64_t temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;

    int64_t persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
    gtl::InlinedVector<std::pair<const void*, int64_t>, 2>
        temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu);
    gtl::InlinedVector<int64_t, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu);
  };
  std::unique_ptr<TrackingState> tracking_state_;

  // For access to `params_->op_kernel`.
  friend void CheckNotInComputeAsync(OpKernelContext* ctx,
                                     const char* correct_macro_name);

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
};

Kernel instantiation

At run time, the following functions are called to create a kernel:


// Instantiate an OpKernel that has been registered.  Returns nullptr
// if no operation for that type of device / input signature combination
// (and a NOT_FOUND *status), or there is an error in construction (and
// an INVALID_ARGUMENT *status).  Otherwise, the caller takes ownership
// of the returned pointer.
// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& node_def,
                                         int graph_def_version, Status* status);

std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);
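
The contract in the comment above (nullptr plus a NOT_FOUND status when nothing is registered for the device and op, otherwise a pointer the caller owns) can be modeled without TensorFlow. This is only a minimal sketch: MiniStatus, MiniKernel, CreateMiniKernel and the "op:device" key format are hypothetical names chosen for illustration, not the real implementation.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

enum class Code { kOk, kNotFound };

// Hypothetical stand-in for tensorflow::Status.
struct MiniStatus {
  Code code = Code::kOk;
  std::string message;
  bool ok() const { return code == Code::kOk; }
};

// Hypothetical stand-in for OpKernel.
struct MiniKernel {
  explicit MiniKernel(std::string n) : name(std::move(n)) {}
  std::string name;
};

using CreateFn = std::unique_ptr<MiniKernel> (*)(const std::string&);

std::unique_ptr<MiniKernel> MakeCpuKernel(const std::string& node_name) {
  return std::unique_ptr<MiniKernel>(new MiniKernel(node_name + "@CPU"));
}

// Assumed lookup table, keyed by "<op>:<device type>".
const std::map<std::string, CreateFn>& Table() {
  static const std::map<std::string, CreateFn> table = {
      {"ZeroOut:CPU", &MakeCpuKernel}};
  return table;
}

// Returns nullptr and sets kNotFound when no kernel matches; otherwise the
// caller takes ownership of the returned kernel.
std::unique_ptr<MiniKernel> CreateMiniKernel(const std::string& device_type,
                                             const std::string& op,
                                             const std::string& node_name,
                                             MiniStatus* status) {
  auto it = Table().find(op + ":" + device_type);
  if (it == Table().end()) {
    status->code = Code::kNotFound;
    status->message = "No kernel for op " + op + " on " + device_type;
    return nullptr;
  }
  *status = MiniStatus();
  return it->second(node_name);
}
```

Calling CreateMiniKernel("CPU", "ZeroOut", "node0", &s) yields a kernel, while asking for "GPU" returns nullptr with s.code set to Code::kNotFound.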

Kernel registration

The example given in the official header comments:


// Register your OpKernel by specifying the Op's name, the device the
// kernel runs on, any type attr constraints for this kernel, any
// host-memory args, and the class to instantiate.  Examples:
//
//  // A kernel that supports all types.
//  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
//
//  // The following are equivalent ways of specifying that the kernel only
//  // works if the "T" type attr is set to DT_FLOAT.
//  REGISTER_KERNEL_BUILDER(
//      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//      SubOp<float>);
//  // (You would then repeat this for every type supported by "Sub".)
//
//  // This form allows you to specify a list of types as the constraint.
//  REGISTER_KERNEL_BUILDER(Name("Sub")
//                              .Device(DEVICE_CPU)
//                              .TypeConstraint("T", {DT_FLOAT}),
//                          SubOp<float>);
//
//  // A kernel that expects one of the input tensors in host memory.
//  REGISTER_KERNEL_BUILDER(
//      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
//
// See kernel_def_builder for details.

Kernel registration likewise relies on macros and a factory.

Walking through REGISTER_KERNEL_BUILDER

The macro expands in stages: REGISTER_KERNEL_BUILDER invokes REGISTER_KERNEL_BUILDER_IMPL, which invokes TF_EXTRACT_KERNEL_NAME, then TF_EXTRACT_KERNEL_NAME_IMPL, then REGISTER_KERNEL_BUILDER_IMPL_2, then TF_NEW_ID_FOR_INIT, and finally REGISTER_KERNEL_BUILDER_IMPL_3.
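
The name-extraction step in this chain is the least obvious one: token pasting turns the leading Name("...") of the builder expression into a macro call that expands to two arguments. The sketch below reproduces only that trick outside TensorFlow. MiniBuilder, Extracted, EXTRACT and MAKE_EXTRACTED are hypothetical names invented for illustration, and the device is a plain string rather than DEVICE_CPU.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for KernelDefBuilder: it only records strings.
struct MiniBuilder {
  std::string op;
  std::string device;
  explicit MiniBuilder(const char* op_name) : op(op_name) {}
  MiniBuilder& Device(const char* d) {
    device = d;
    return *this;
  }
};

// Mirrors register_kernel::Name: a builder that can be spelled Name("...").
struct Name : MiniBuilder {
  explicit Name(const char* op_name) : MiniBuilder(op_name) {}
};

// Holds the pieces the final macro stage receives.
struct Extracted {
  std::string op_name;         // the op name pulled out as a string literal
  std::string builder_op;      // op recorded by the builder expression
  std::string builder_device;  // device recorded by the builder expression
  std::string class_name;      // stringified kernel class
};

// Pasting EXTRACT_ onto the leading token `Name` turns Name("X") into the
// macro call EXTRACT_Name("X"), which expands to two arguments.
#define EXTRACT_Name(name_str) name_str, Name(name_str)
#define EXTRACT_IMPL(m, ...) m(__VA_ARGS__)
#define EXTRACT(m, kernel_builder, ...) \
  EXTRACT_IMPL(m, EXTRACT_##kernel_builder, __VA_ARGS__)

// Final stage: sees the op name and the builder as separate arguments.
#define MAKE_EXTRACTED(op_name, builder_expr, class_name)          \
  Extracted {                                                      \
    op_name, (builder_expr).op, (builder_expr).device, #class_name \
  }

Extracted Demo() {
  return EXTRACT(MAKE_EXTRACTED, Name("ZeroOut").Device("CPU"), ZeroOutOp);
}
```

Demo() ends up as MAKE_EXTRACTED("ZeroOut", Name("ZeroOut").Device("CPU"), ZeroOutOp), so the op name is available as a literal on its own, which is the point of the extra macro layers. The sketch assumes a conforming preprocessor such as GCC's or Clang's.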

// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
// TODO(dodgen): There are some uses of this macro inside functions, where
// kernel_builder refers to (non-const) locals (they should be fixed). To
// accommodate those, kernel_builder.Build() appears as an argument to an
// immediately-called lambda (not in the lambda itself).
#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr,   \
                                       is_system_kernel, ...)               \
  static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr      \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF(is_system_kernel ||                         \
                                (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
                                 SHOULD_REGISTER_OP(op_name)))              \
          << ([](::tensorflow::KernelDef const* kernel_def) {               \
               /* Execution ends up here: OpKernelRegistrar (from */        \
               /* kernel_factory) registers a factory lambda.     */        \
               ::tensorflow::kernel_factory::OpKernelRegistrar registrar(   \
                   kernel_def, #__VA_ARGS__,                                \
                   [](::tensorflow::OpKernelConstruction* context)          \
                       -> ::tensorflow::OpKernel* {                         \
                     return new __VA_ARGS__(context); /* new ZeroOutOp */   \
                   });                                                      \
               (void)registrar;                                             \
               return ::tensorflow::InitOnStartupMarker{};                  \
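             })(kernel_builder_expr.Build());
// kernel_builder_expr is the KernelDefBuilder expression, so for ZeroOut this
// is effectively Name("ZeroOut").Device(DEVICE_CPU).Build(). The lambda is
// invoked immediately, so execution enters its body right away.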

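In REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp), Name is actually a KernelDefBuilder (register_kernel::Name derives from it), and Device is KernelDefBuilder::Device.

So the call has the shape REGISTER_KERNEL_BUILDER(<KernelDefBuilder object>, <the ZeroOutOp class>).

The OpKernelRegistrar constructor ultimately records the kernel in the global registry returned by GlobalKernelRegistry():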

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = []() {
    KernelRegistry* registry = new KernelRegistry;
    OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
    return registry;
  }();
  return global_kernel_registry;
}

struct KernelRegistry {
  mutex mu;
  // Every registration ends up in this multimap.
  std::unordered_multimap<string, KernelRegistration> registry
      TF_GUARDED_BY(mu);
};
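
The pattern behind these two pieces (a lazily created global map, plus an object whose constructor inserts into it during static initialization) can be sketched in a few lines of standalone C++. All names here (MiniKernel, MiniRegistrar, GlobalMiniRegistry, the "ZeroOut:CPU" key) are assumptions made for the illustration, and the sketch leaves out the mutex that guards the real registry.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical miniature of OpKernel.
struct MiniKernel {
  virtual ~MiniKernel() = default;
  virtual int Compute(int x) = 0;
};

using MiniFactory = std::function<std::unique_ptr<MiniKernel>()>;

// Function-local static: built on first use, so registrars that run during
// static initialization never touch an unconstructed map.
std::unordered_multimap<std::string, MiniFactory>& GlobalMiniRegistry() {
  static auto* registry =
      new std::unordered_multimap<std::string, MiniFactory>;
  return *registry;
}

// Constructing a registrar has the side effect of adding a factory.
struct MiniRegistrar {
  MiniRegistrar(const std::string& key, MiniFactory factory) {
    GlobalMiniRegistry().emplace(key, std::move(factory));
  }
};

// A toy kernel that always outputs zero.
struct ZeroOutLike : MiniKernel {
  int Compute(int) override { return 0; }
};

// Namespace-scope static: its initializer runs before main() and registers
// the factory lambda, which is what the registration macro arranges.
static MiniRegistrar register_zero_out_cpu(
    "ZeroOut:CPU",
    [] { return std::unique_ptr<MiniKernel>(new ZeroOutLike); });
```

A multimap is used because one key may have several registrations; looking up "ZeroOut:CPU" with equal_range and calling the stored factory yields a fresh ZeroOutLike.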

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op) : KernelDefBuilder(op) {}
};

}  // namespace register_kernel

// Kernel registration appears as:
//   REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl)
// We'd like to have "OpName" as a constant-expression, without requiring that
// of the overall KernelDefBuilder expression (beginning with the
// register_kernel::Name constructor above).
//
// So, we pull the "OpName" part to a separate macro-level argument. This
// involves treating Name("OpName") as a macro call, via token-pasting (e.g.
// M_## =>  M_Name("OpName")), and having it expand to '"OpName",
// Name("OpName")' which is then usable as two arguments.
#define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
  name_str, ::tensorflow::register_kernel::Name(name_str)
#define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
#define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...)                    \
  TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
                              __VA_ARGS__)

// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
// TODO(dodgen): There are some uses of this macro inside functions, where
// kernel_builder refers to (non-const) locals (they should be fixed). To
// accommodate those, kernel_builder.Build() appears as an argument to an
// immediately-called lambda (not in the lambda itself).
#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr,   \
                                       is_system_kernel, ...)               \
  static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr      \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF(is_system_kernel ||                         \
                                (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
                                 SHOULD_REGISTER_OP(op_name)))              \
          << ([](::tensorflow::KernelDef const* kernel_def) {               \
               ::tensorflow::kernel_factory::OpKernelRegistrar registrar(   \
                   kernel_def, #__VA_ARGS__,                                \
                   [](::tensorflow::OpKernelConstruction* context)          \
                       -> ::tensorflow::OpKernel* {                         \
                     return new __VA_ARGS__(context);                       \
                   });                                                      \
               (void)registrar;                                             \
               return ::tensorflow::InitOnStartupMarker{};                  \
             })(kernel_builder_expr.Build());

// REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split to op_name,
// kernel_builder_expr.
#define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \
                                       is_system_kernel, ...)        \
  TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name,        \
                     kernel_builder_expr, is_system_kernel, __VA_ARGS__)

// REGISTER_KERNEL_BUILDER, but with is_system_kernel bound.
#define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \
  TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder,    \
                         is_system_kernel, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                 \
  REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__)

// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
// unconditionally even when selective registration is used.
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                        \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel:system")                 \
  REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__)

// Checks whether a given kernel is registered on device_type.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);

// If node of node_name, experimental_debug_info, node_op, node_device and
// node_attrs has a corresponding kernel registered on device_type, returns OK
// and fill in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
    const KernelDef** def, std::string* kernel_class_name);

// If node_def has a corresponding kernel registered on device_type,
// returns OK and fill in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, std::string* kernel_class_name);

// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing kernel errors.
void LogAllRegisteredKernels();

// Gets a list of all registered kernels.
KernelList GetAllRegisteredKernels();

// Gets a list of all registered kernels for which predicate returns true
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate);

// Gets a list of all registered kernels for a given op
KernelList GetRegisteredKernelsForOp(StringPiece op_name);

namespace kernel_factory {

// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
// them. You register factories with the TensorFlow core by constructing an
// OpKernelRegistrar and passing the factory as a constructor parameter.
class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};

class OpKernelRegistrar {
 public:
  // Registers the given kernel factory with TensorFlow. TF will call the
  // factory Create() method when it determines that a kernel matching the given
  // KernelDef is required.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory) {
    InitInternal(kernel_def, kernel_class_name, std::move(factory));
  }

  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    InitInternal(kernel_def, kernel_class_name,
                 absl::make_unique<PtrOpKernelFactory>(create_fn));
  }

 private:
  struct PtrOpKernelFactory : public OpKernelFactory {
    explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
        : create_func_(create_func) {}

    OpKernel* Create(OpKernelConstruction* context) override;

    OpKernel* (*create_func_)(OpKernelConstruction*);
  };

  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory);
};

}  // namespace kernel_factory
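
The lambda in REGISTER_KERNEL_BUILDER_IMPL_3 captures nothing, so it converts implicitly to a plain function pointer and matches the second OpKernelRegistrar constructor, which wraps the pointer in PtrOpKernelFactory. Below is a rough standalone model of that adapter; MiniKernel, MiniFactory, PtrFactory and WrapCreateFn are hypothetical names, and the int argument merely stands in for OpKernelConstruction*.

```cpp
#include <cassert>
#include <memory>

// Hypothetical miniature of OpKernel.
struct MiniKernel {
  int id;
};

// Abstract factory, in the spirit of OpKernelFactory.
struct MiniFactory {
  virtual MiniKernel* Create(int id) = 0;
  virtual ~MiniFactory() = default;
};

// Adapter that turns a plain function pointer into a factory object.
struct PtrFactory : MiniFactory {
  explicit PtrFactory(MiniKernel* (*fn)(int)) : fn_(fn) {}
  MiniKernel* Create(int id) override { return (*fn_)(id); }
  MiniKernel* (*fn_)(int);
};

std::unique_ptr<MiniFactory> WrapCreateFn(MiniKernel* (*fn)(int)) {
  return std::unique_ptr<MiniFactory>(new PtrFactory(fn));
}

// The capture-less lambda converts to MiniKernel* (*)(int) at the call.
std::unique_ptr<MiniFactory> MakeDemoFactory() {
  return WrapCreateFn([](int id) -> MiniKernel* { return new MiniKernel{id}; });
}
```

Create() returns a raw pointer, as in the original interface, so the caller should wrap the result in a std::unique_ptr straight away.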

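Loading kernels from dynamic libraries

tensorflow/core/framework/op_kernel.cc

Search directory: tensorflow/core/kernels under the runfiles directory. Every file there that matches kKernelLibPattern (the shared libraries) is loaded, provided it passes the IsProbablySafeToLoad check or the ABI override is set. The loading itself goes through Env::LoadDynamicLibrary.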


void LoadDynamicKernelsInternal() {
  Env* env = Env::Default();

  // Override to allow loading unsafe packages for development.
  // DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
  char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
  bool override_abi_check = false;
  if (_abi_check_env_var != nullptr) {
    override_abi_check = strcmp(_abi_check_env_var, "1") == 0;
  }

  string bazel_kernel_dir =
      io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
  std::vector<string> files;
  Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
  if (s_kernel_dir.ok()) {
    string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
    for (const auto& file : files) {
      string fullpath = io::JoinPath(bazel_kernel_dir, file);
      if (env->MatchPath(fullpath, dll_spec)) {
        Status s = IsProbablySafeToLoad(fullpath);
        if (!s.ok() && override_abi_check) {
          LOG(WARNING) << "Loading UNSAFE library " << fullpath
                       << " because ABI check override is set: "
                       << s.error_message();
        }
        if (s.ok() || override_abi_check) {
          // TODO(gunan): Store the handles to the opened files.
          void* unused_filehandle;
          TF_CHECK_OK(
              env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle));
        } else {
          LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
                       << s.error_message();
        }
      }
    }
  }
}

// Mechanism for loading existing kernel libraries.
void LoadDynamicKernels() {
  // TODO(gunan): As more features are available, add intelligent kernel
  // selection, and dropping unsuitable kernel logic here.
  static absl::once_flag dll_loader_flag;
  absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}
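
Three small decisions in the loader are easy to miss: the override is enabled only by the exact value "1", a library is loaded when the ABI check passes or the override is set, and the whole scan runs once per process. A minimal sketch of that logic follows, with hypothetical function names. Note that the run-once part uses a function-local static here instead of absl::call_once, a substitution made to keep the sketch dependency-free.

```cpp
#include <cassert>
#include <cstring>

// Mirrors the override parsing: only the exact string "1" enables it.
bool ParseOverride(const char* env_value) {
  return env_value != nullptr && std::strcmp(env_value, "1") == 0;
}

// A library is loaded when the ABI check passes or the override is set.
bool ShouldLoad(bool abi_check_ok, bool override_abi_check) {
  return abi_check_ok || override_abi_check;
}

int g_scan_count = 0;  // how many times the (pretend) directory scan ran

void ScanInternal() { ++g_scan_count; }

// C++11 guarantees a function-local static is initialized exactly once, even
// when several threads call ScanOnce() concurrently.
void ScanOnce() {
  static const bool done = (ScanInternal(), true);
  (void)done;
}
```

Calling ScanOnce() any number of times leaves g_scan_count at 1, which is the behaviour LoadDynamicKernels relies on.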

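Returning errors when a kernel reads inputs from the context or allocates outputs

tensorflow/core/framework/op_requires.h defines a number of macros that help with this. When a check fails, they record the error on the context and return from the calling function.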

#define OP_REQUIRES_OK(CTX, ...)                             \
  do {                                                       \
    ::tensorflow::Status _s(__VA_ARGS__);                    \
    if (!TF_PREDICT_TRUE(_s.ok())) {                         \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);  \
      return;                                                \
    }                                                        \
  } while (0)
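
The shape of this macro (evaluate a status, record the failure on the context, then return from the enclosing function) can be tried out in isolation. The following is a simplified model, not the TensorFlow macro: MiniStatus, MiniContext, MINI_REQUIRES_OK, Allocate and Compute are hypothetical, and the async check is left out.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for tensorflow::Status.
struct MiniStatus {
  bool is_ok = true;
  std::string message;
  bool ok() const { return is_ok; }
  static MiniStatus Error(std::string m) {
    MiniStatus s;
    s.is_ok = false;
    s.message = std::move(m);
    return s;
  }
};

// Hypothetical stand-in for OpKernelContext: it only stores the failure.
struct MiniContext {
  MiniStatus status;
  int failure_line = 0;
  void CtxFailure(int line, const MiniStatus& s) {
    status = s;
    failure_line = line;
  }
};

// do { ... } while (0) makes the macro behave like a single statement.
// The bare `return;` means it can only be used in functions returning void.
#define MINI_REQUIRES_OK(CTX, ...)     \
  do {                                 \
    MiniStatus _s(__VA_ARGS__);        \
    if (!_s.ok()) {                    \
      (CTX)->CtxFailure(__LINE__, _s); \
      return;                          \
    }                                  \
  } while (0)

// Pretend allocation that fails for negative sizes.
MiniStatus Allocate(int n, int* out) {
  if (n < 0) return MiniStatus::Error("negative size");
  *out = n;
  return MiniStatus();
}

void Compute(MiniContext* ctx, int n, bool* reached_end) {
  int allocated = 0;
  MINI_REQUIRES_OK(ctx, Allocate(n, &allocated));
  *reached_end = true;  // skipped when the allocation fails
}
```

With n = -1 the macro stores "negative size" on the context and Compute returns before setting *reached_end, which is the behaviour seen with OP_REQUIRES_OK in the ZeroOut example.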
