In TensorFlow, a graph node is called an OP (operator). OPs are written in C++ and can be freely extended by users. Extending an OP takes two steps: (1) declare the OP, i.e. register it with REGISTER_OP; (2) implement the OP as an op kernel, a subclass of OpKernel, and register that with REGISTER_KERNEL_BUILDER.
At graph-construction time only the OP declaration is needed; the kernel is looked up and instantiated at run time, and one OP can have different kernel implementations on different devices. Below is the official minimal example: the ZeroOut OP declaration and its kernel. Declaration and implementation can live in entirely separate files. OP registration was analyzed in detail in the earlier post "tensorflow之op" (wyg_031113's CSDN blog); this article focuses on the kernel.
The kernel is where the actual computation happens.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
// OP declaration
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
// OP implementation (the kernel)
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the first input value if possible.
if (N > 0) output_flat(0) = input(0);
}
};
// Register the kernel
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
The Kernel interface
tensorflow/core/framework/op_kernel.h
Synchronous computation: the Compute method
- Kernel computation can be synchronous or asynchronous. Compute() must be thread-safe; most kernels are synchronous.
- A synchronous kernel must never block the current thread on a mutex, condition variable, or similar primitive while expecting some other kernel to unblock it.
- The executor may own only a fixed number of threads; if they all block, execution deadlocks.
- Kernels that genuinely have to wait (e.g. RecvOp, DequeueOp) must derive from AsyncOpKernel, a subclass of OpKernel (a sketch follows this list).
- In most cases an AsyncOpKernel should also use the cancellation mechanism: context->cancellation_manager().
- An op obtains all of its inputs and outputs through the OpKernelContext argument, and reports its status via context->SetStatus().
- For synchronous computation, the context is guaranteed to stay alive until Compute returns.
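As promised above, here is a minimal hedged sketch of an asynchronous kernel. The op name and the thread-based wait are purely illustrative (real kernels such as RecvOp hook into rendezvous callbacks instead); the sketch shows the two contractual points: check for cancellation, and invoke done() exactly once.
// Illustrative only; needs <thread>. Not how RecvOp/DequeueOp are written.
class HypotheticalWaitOp : public AsyncOpKernel {
 public:
  explicit HypotheticalWaitOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // Bail out early if this step has already been cancelled.
    CancellationManager* cm = ctx->cancellation_manager();
    if (cm != nullptr && cm->IsCancelled()) {
      ctx->SetStatus(errors::Cancelled("HypotheticalWaitOp cancelled"));
      done();
      return;
    }
    // Move the blocking wait off the executor thread. `ctx` remains valid
    // until `done()` runs, and `done()` must be invoked exactly once.
    std::thread([ctx, done]() {
      // ... blocking wait / long-running device work would go here ...
      done();
    }).detach();
  }
};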
Construction and destruction
class OpKernel {
public:
// Kernels are not constructed on the scheduler's threads, so subclasses may perform expensive initialization in their constructors
explicit OpKernel(OpKernelConstruction* context);
// Allows deferred ops: the executor will use the OpKernelContext::inc_num_deferred_ops_function()
// and OpKernelContext::dec_num_deferred_ops_function() methods at run time.
OpKernel(OpKernelConstruction* context, bool is_deferred);
// Allows subclasses to supply a custom NodeDef
OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
bool is_deferred);
virtual ~OpKernel();
// The core computation; subclasses override this to implement their functionality
virtual void Compute(OpKernelContext* context) = 0;
// Returns nullptr iff this op kernel is synchronous.
virtual AsyncOpKernel* AsAsync() { return nullptr; }
// Returns true iff this op kernel is considered "expensive". The
// runtime may use this flag to optimize graph execution for example
// to "inline" inexpensive kernels.
virtual bool IsExpensive() { return expensive_; }
// Returns a pointer to the tensor stored inside constant ops.
virtual const Tensor* const_tensor() const { return nullptr; }
// Accessors: node definition, node name, types, and so on.
const NodeDef& def() const { return props_->node_def; }
const std::string& name() const { return props_->node_def.name(); }
absl::string_view name_view() const { return name_view_; }
const std::string& type_string() const { return props_->node_def.op(); }
absl::string_view type_string_view() const { return type_string_view_; }
const std::string& requested_input(int i) const {
return props_->node_def.input(i);
}
const std::string& requested_device() const {
return props_->node_def.device();
}
int num_inputs() const { return props_->input_types.size(); }
DataType input_type(int i) const { return props_->input_types[i]; }
const DataTypeVector& input_types() const { return props_->input_types; }
const MemoryTypeVector& input_memory_types() const {
return input_memory_types_;
}
int num_outputs() const { return props_->output_types.size(); }
DataType output_type(int o) const { return props_->output_types[o]; }
const DataTypeVector& output_types() const { return props_->output_types; }
const MemoryTypeVector& output_memory_types() const {
return output_memory_types_;
}
Status InputRange(StringPiece input_name, int* start, int* stop) const;
Status OutputRange(StringPiece output_name, int* start, int* stop) const;
// Returns `true` if and only if this kernel uses deferred execution.
bool is_deferred() const { return is_deferred_; }
// Returns a trace string for current computation, op name/type and input
// tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
// should use the default implementation.
virtual std::string TraceString(const OpKernelContext& ctx,
bool verbose) const;
protected:
std::string ShapeTraceString(const OpKernelContext& ctx) const;
private:
const std::shared_ptr<const NodeProperties> props_;
const MemoryTypeVector input_memory_types_;
const MemoryTypeVector output_memory_types_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
const absl::string_view name_view_;
const absl::string_view type_string_view_;
const int graph_def_version_;
const bool is_deferred_;
bool expensive_;
TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};
Asynchronous computation: AsyncOpKernel
class AsyncOpKernel : public OpKernel {
public:
using OpKernel::OpKernel; // Lift OpKernel constructors.
// Call this callback to notify the scheduler once the asynchronous computation completes.
// It must be called exactly once; after it runs, `context` and `this` may already be destroyed.
typedef std::function<void()> DoneCallback;
// Asynchronous kernels override this instead of Compute
virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
AsyncOpKernel* AsAsync() override { return this; }
void Compute(OpKernelContext* context) override;
};
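For reference, the synchronous Compute that AsyncOpKernel provides (abridged from op_kernel.cc) simply bridges to ComputeAsync and blocks on a Notification, which is why invoking an async kernel synchronously works but forfeits its benefit:
// Abridged sketch of the default bridge in op_kernel.cc.
void AsyncOpKernel::Compute(OpKernelContext* context) {
  Notification n;
  ComputeAsync(context, [&n]() { n.Notify(); });
  n.WaitForNotification();  // block until the done callback fires
}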
OpKernelConstruction: the kernel's construction-time context
It hands the kernel the following (a GetAttr sketch follows the list):
- 设备:device
- 分配器Allocator
- 资源管理器:ResourceMgr
- Node
- Env
- FunctionLib
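Its most common job in practice is serving GetAttr in kernel constructors. A sketch, assuming the ZeroOut declaration also carried .Attr("preserve_index: int") as in the official custom-op guide:
// Assumes REGISTER_OP("ZeroOut") also declared .Attr("preserve_index: int").
class ZeroOutWithAttrOp : public OpKernel {
 public:
  explicit ZeroOutWithAttrOp(OpKernelConstruction* context) : OpKernel(context) {
    // Attrs are fixed in the graph, so fetch them once at construction time.
    OP_REQUIRES_OK(context, context->GetAttr("preserve_index", &preserve_index_));
    // Validate, recording an error status on the construction context on failure.
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }

 private:
  int preserve_index_;
};
The class itself: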
class OpKernelConstruction {
public:
OpKernelConstruction(DeviceType device_type, DeviceBase* device,
Allocator* allocator, FunctionLibraryRuntime* flib,
ResourceMgr* resource_mgr,
const std::shared_ptr<const NodeProperties>& props,
const MemoryTypeSlice& input_memory_types,
const MemoryTypeSlice& output_memory_types,
int graph_def_version, Status* status);
Env* env() const { return device_->env(); }
Status allocate_temp(DataType type, const TensorShape& shape,
Tensor* out_temp);
Status allocate_temp(DataType type, const TensorShape& shape,
Tensor* out_temp, AllocatorAttributes allocator_attr);
// User-supplied configuration of this operation.
const NodeDef& def() const { return props_->node_def; }
// For inspecting the inputs to this operation.
int num_inputs() const { return props_->input_types.size(); }
DataType input_type(int i) const { return props_->input_types[i]; }
const DataTypeSlice& input_types() const { return props_->input_types_slice; }
const MemoryTypeSlice& input_memory_types() const {
return input_memory_types_;
}
// For inspecting the outputs expected from this operation.
int num_outputs() const { return props_->output_types.size(); }
DataType output_type(int i) const { return props_->output_types[i]; }
const DataTypeSlice& output_types() const {
return props_->output_types_slice;
}
const MemoryTypeSlice& output_memory_types() const {
return output_memory_types_;
}
// If expected_inputs == inputs() and expected_outputs == output_types(),
// returns OK, else returns INVALID_ARGUMENT with an error message.
// Recommended for Ops with dynamic signatures.
Status MatchSignature(const DataTypeSlice expected_inputs,
const DataTypeSlice expected_outputs);
// For recording configuration errors during construction.
void SetStatus(const Status& status);
const Status& status() const { return *status_; }
// Look up the attr with name attr_name and set *value to its value. If no
// attr with attr_name is found in def(), or the attr does not have
// a matching type, a non-ok status will be returned.
template <class T>
Status GetAttr(StringPiece attr_name, T* value) const;
// Return true if the attr_name is defined in def().
bool HasAttr(StringPiece attr_name) const;
// Return the device type.
const DeviceType& device_type() const { return device_type_; }
// If not nullptr, the kernel can instantiate functions defined in
// the library. E.g.,
// CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
FunctionLibraryRuntime* function_library() const { return flib_; }
// Shared resources accessible to this kernel.
ResourceMgr* resource_manager() const { return resource_mgr_; }
// The GraphDef version whose behavior we should follow.
int graph_def_version() const { return graph_def_version_; }
// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);
void CtxFailureWithWarning(const Status& s);
void CtxFailure(const char* file, int line, const Status& s);
void CtxFailureWithWarning(const char* file, int line, const Status& s);
// Unrecommended functions: these are functions that have some
// current uses but are not recommended for use, and may go away at
// some future major version release.
// May be used, e.g., to get GPU handles, etc.
//
// Currently only used to call MakeTensorFromProto() for
// implementing ConstantOp for every device. See comments
// on Device::MakeTensorFromProto for longer-term replacement
// ideas.
DeviceBase* device() const { return device_; }
private:
const DeviceType device_type_;
DeviceBase* const device_;
Allocator* allocator_;
FunctionLibraryRuntime* flib_;
ResourceMgr* const resource_mgr_;
std::shared_ptr<const NodeProperties> props_;
MemoryTypeSlice input_memory_types_;
MemoryTypeSlice output_memory_types_;
const int graph_def_version_;
Status* status_;
// Allow access from OpKernel ctor.
friend class OpKernel;
TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};
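For kernels with dynamic signatures, MatchSignature lets the constructor validate the wiring up front. A minimal sketch of a hypothetical (float, float) -> float kernel:
// Hypothetical kernel: verify at construction time that the node was
// wired up as (float, float) -> float.
class BinaryFloatOp : public OpKernel {
 public:
  explicit BinaryFloatOp(OpKernelConstruction* context) : OpKernel(context) {
    // Records INVALID_ARGUMENT on the construction context on mismatch.
    OP_REQUIRES_OK(context,
                   context->MatchSignature({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}));
  }
  void Compute(OpKernelContext* context) override { /* ... */ }
};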
Helper classes for OP input/output arguments
Some inputs are list-valued: a single name stands for several inputs of the same type, conceptually Tensor tensors[N]; outputs can be list-valued in the same way. The helpers (a consumption sketch follows the list):
- OpInputList
- OpMutableInputList
- OpOutputList
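A hedged sketch of consuming such a list inside Compute, assuming the OpDef declares an input like values: N * float:
// Assumes the OpDef declares something like .Input("values: N * float").
void Compute(OpKernelContext* ctx) override {
  OpInputList values;
  OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
  for (int i = 0; i < values.size(); ++i) {
    const Tensor& t = values[i];  // every element has the declared type
    // ... process t ...
  }
}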
OpKernelContext: the argument to Compute
This class is huge and rich: it provides everything an op needs while Compute runs. Logically, its facilities fall into a few groups.
Input/output access
Input and Output are fetched here. Attrs, by contrast, are fixed at graph-construction time and are already available from OpKernelConstruction.
Producing outputs also involves allocating tensor memory.
Execution environment
env, device, resource_mgr, node, graph, session, step_id, function_library, allocator
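As one example of what the execution environment enables, a kernel can hand Eigen expressions to the device. A hedged sketch of a float-squaring kernel using the CPU thread-pool device (the op itself is illustrative):
// Square the input elementwise (sketch; assumes a single float input).
void Compute(OpKernelContext* ctx) override {
  const Tensor& input = ctx->input(0);
  Tensor* output = nullptr;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
  // Evaluate the expression in parallel on the executor's Eigen thread pool.
  const Eigen::ThreadPoolDevice& d = ctx->eigen_cpu_device();
  output->flat<float>().device(d) = input.flat<float>().square();
}
The (abridged) class declaration follows: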
class OpKernelContext {
public:
// The first element of a WrappedAllocator is a "base" Allocator and
// the second element is that Allocator wrapped by a
// TrackingAllocator
typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
// TODO(zhifengc): Do some cleanup of Params.
// The Params struct is passed in to initialize an OpKernelContext,
// and must outlive the OpKernelContext.
struct Params {
~Params() { delete eigen_gpu_device; }
// The step being executed.
int64_t step_id = 0;
// Timestamp for the start of graph execution. Used for latency metrics.
int64_t start_time_usecs = 0;
// The deadline for the session to complete by. Empty if unspecified.
absl::optional<absl::Time> deadline;
// The op kernel being computed.
OpKernel* op_kernel = nullptr;
// The device on which the kernel is running.
DeviceBase* device = nullptr;
// The Eigen GPU device wrapper, which may include a per-op
// wrapped allocator. The concrete type of this object depends on
// the type of this->device, so eigen_gpu_device can't be an
// inline member and must be heap allocated. However, we don't
// want to allocate a new eigen_gpu_device for every Op that is
// executed. Instead this member is allocated on first use using
// ensure_eigen_gpu_device, and then if the Params structure is
// re-used for subsequent Ops, the eigen_gpu_device is
// ReInitialized in the OpKernelContext constructor. Unlike the
// other pointers in Params, this one is owned by Params.
PerOpGpuDevice* eigen_gpu_device = nullptr;
inline void ensure_eigen_gpu_device() {
DCHECK(device);
if (nullptr == eigen_gpu_device) {
// Surprisingly, MakeGpuDevice will return nullptr if the
// device is not a GPU device. This is ok, since those devices
// will never use eigen_gpu_device. It seems better to have
// ensure_eigen_gpu_device fall through and regenerate the
// nullptr every time an OpKernelContext is instantiated, than
// to do an unnecessary allocation of a dummy eigen GPU
// device for CPU device Ops.
eigen_gpu_device = device->MakeGpuDevice();
}
}
bool track_allocations = false;
bool log_memory = false;
// Array indexed by output number for this node
const AllocatorAttributes* output_attr_array = nullptr;
// Shared resources accessible by this op kernel invocation.
ResourceMgr* resource_manager = nullptr;
// Per-step resources accessible by this op kernel invocation should be
// stored in this container.
ScopedStepContainer* step_container = nullptr;
// Mechanism used by this op kernel invocation to communicate with
// computations running on other devices.
RendezvousInterface* rendezvous = nullptr;
// Mechanism for executing a collective op that needs to coordinate
// with parallel instances running on other devices.
CollectiveExecutor* collective_executor = nullptr;
// The session state for this op.
SessionState* session_state = nullptr;
// Unique session identifier. Can be empty.
std::string session_handle;
// Metadata about the session. Can be nullptr.
const SessionMetadata* session_metadata = nullptr;
// The tensor store for this op.
TensorStore* tensor_store = nullptr;
// Mechanism used by this op kernel invocation to register a callback
// for its cancellation.
CancellationManager* cancellation_manager = nullptr;
// Inputs to this op kernel.
const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
bool is_input_dead = false;
const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
nullptr;
// Device context.
DeviceContext* op_device_context = nullptr;
// Control-flow op supports.
FrameAndIter frame_iter;
// Function call supports.
CallFrameInterface* call_frame = nullptr;
FunctionLibraryRuntime* function_library = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
StepStatsCollectorInterface* stats_collector = nullptr;
GraphCollector* graph_collector = nullptr;
bool run_all_kernels_inline = false;
const std::string* executor_type = nullptr;
// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
// Support for forwarding reservations (used by ScopedAllocator).
static constexpr int kNeverForward = -2;
static constexpr int kNoReservation = -1;
// Values in [0,...) represent reservations for the indexed output.
const int* forward_from_array = nullptr;
// For tracking actively running deferred ops.
std::function<void()> inc_num_deferred_ops_function;
std::function<void()> dec_num_deferred_ops_function;
absl::optional<ManagedStackTrace> stack_trace = {};
// For implementing `OpKernelContext::output_required()`. If null, all
// outputs are required.
bool* outputs_required_array = nullptr;
// For access to distributed coordination service.
CoordinationServiceAgent* coordination_service_agent = nullptr;
};
// params must outlive the OpKernelContext.
explicit OpKernelContext(Params* params);
OpKernelContext(Params* params, int num_outputs);
~OpKernelContext();
Env* env() const { return params_->device->env(); }
int64_t step_id() const { return params_->step_id; }
int64_t start_time_usecs() const { return params_->start_time_usecs; }
// The deadline for the session to complete by. Empty if unspecified in
// RunOptions.
absl::optional<absl::Time> deadline() const { return params_->deadline; }
const OpKernel& op_kernel() const { return *params_->op_kernel; }
// Stack trace of where the op was defined (if defined in eager mode).
const absl::optional<ManagedStackTrace>& stack_trace() const {
return params_->stack_trace;
}
// Input/output signature.
int num_inputs() const { return params_->inputs->size(); }
DataType input_dtype(int index) const;
Status input_dtype(StringPiece name, DataType* dtype) const;
MemoryType input_memory_type(int index) const;
int num_outputs() const { return outputs_.size(); }
DataType expected_output_dtype(int index) const;
MemoryType output_memory_type(int index) const;
// Input
// Returns an immutable input tensor. May only be used for non-Ref
// inputs. For Ref inputs use mutable_input below.
// REQUIRES: !IsRefType(input_dtype(index))
// TODO(mrry): Convert this to return Status.
const Tensor& input(int index) const;
// Returns the named immutable input tensor in "tensor", as defined
// in the OpDef. May only be used for non-Ref inputs. For Ref inputs
// use mutable_input below.
// REQUIRES: !IsRefType(input_dtype(index))
// REQUIRES: the named input must not be a list.
Status input(StringPiece name, const Tensor** tensor);
// Returns the named list-valued immutable input in "list", as
// defined in the OpDef. If the named output is not list-valued,
// returns a one-element list. May only be used for non-Ref
// inputs. For Ref inputs use mutable_input below.
// REQUIRES: !IsRefType(input_dtype(index))
Status input_list(StringPiece name, OpInputList* list);
// For mutable inputs, use the following together to make sure there
// is no concurrent access to mutable_input(), e.g.:
// {
// Tensor& t = context->mutable_input(index);
// mutex_lock lock(*context->input_ref_mutex(index));
// // modify the values in t
// }
// REQUIRES: IsRefType(input_dtype(index))
Status input_ref_mutex(StringPiece name, mutex** out_mutex);
// Returns a mutable input tensor. Must be used to access Ref
// inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
// modify the values stored in the Tensor buffer, and modifications
// will be visible to other Ops reading the same ref tensor. If
// !lock_held the input mutex will be acquired before returning the
// Tensor.
// TODO(mrry): Convert this to return Status.
Tensor mutable_input(int index, bool lock_held);
// Returns the named mutable input tensor in "tensor", as defined in
// the OpDef. Must be used to access Ref inputs. The values stored
// in the Tensor buffer may be modified, and modifications will be
// visible to other Ops reading the same ref tensor. If !lock_held
// the input mutex will be acquired before returning the Tensor.
// REQUIRES: the named input must not be a list.
// REQUIRES: the named input must be a ref tensor.
Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
// Returns the named list-valued mutable input in "list", as defined
// in the OpDef. If the named input is not list-valued, returns a
// one-element list. Must be used to access Ref inputs. The values
// stored in the Tensor buffer may be modified, and modifications
// will be visible to other Ops reading the same ref tensor.
// REQUIRES: the named input must be a ref tensor.
Status mutable_input_list(StringPiece name, OpMutableInputList* list);
// Replace the corresponding Ref Input to use the storage buffer
// used by tensor. If !lock_held the input mutex will be acquired
// before returning the Tensor.
// REQUIRES: IsRefType(input_dtype(index)).
void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
// Replace the corresponding named Ref Input to use the storage
// buffer used by tensor. If !lock_held the input mutex will be
// acquired before returning the Tensor.
// REQUIRES: IsRefType(input_dtype(index)).
Status replace_ref_input(StringPiece name, const Tensor& tensor,
bool lock_held);
// Deletes the Tensor object used as the Ref Input at
// input_index. This is not usually necessary and should be used
// with caution. If !lock_held the input mutex will be acquired
// before returning the Tensor.
// REQUIRES: IsRefType(input_dtype(input_index)).
void delete_ref_input(int input_index, bool lock_held);
// Return true if there is input at the given index. An operator has no
// input at index if its tensor is null. This is primarily used by the
// merge operator.
// TODO(mrry): Convert this to return Status.
bool has_input(int index) const;
// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
// Usage: if (!context->ValidateInputsAreSameShape(this)) return;
bool ValidateInputsAreSameShape(OpKernel* op);
// If non-null, kernels should populate with any partition subgraphs created.
GraphCollector* graph_collector() { return params_->graph_collector; }
// If True, hint that all kernels in functions called by this kernel, should
// be treated as "inexpensive", and hence executed on the scheduling thread.
bool run_all_kernels_inline() const {
return params_->run_all_kernels_inline;
}
// Returns the registered name for the executor type that is executing the
// current kernel. If empty, the default executor is used.
const std::string& executor_type() const;
// Input to output forwarding.
// Set the output Ref Tensor at output_index to be an alias of the
// input Ref Tensor at input_index.
// REQUIRES: IsRefType(input_dtype(input_index)).
// REQUIRES: IsRefType(output_dtype(output_index)).
void forward_ref_input_to_ref_output(int input_index, int output_index);
// Returns true when an alias to input[input_index], reshaped to output_shape,
// which is safe to use for in-place computation was written to *output.
// Returns false if input[input_index] has a refcount greater than one, or if
// its type does not match the expected output type of output[output_index],
// or the number of elements in input[input_index] does not equal the number
// of elements in output_shape.
bool forward_input_to_output_with_shape(int input_index, int output_index,
const TensorShape& output_shape,
Tensor** output) TF_MUST_USE_RESULT;
Status forward_input_to_output_with_shape(StringPiece input_name,
StringPiece output_name,
const TensorShape& output_shape,
Tensor** output) TF_MUST_USE_RESULT;
// Returns a pointer to a Tensor aliasing the underlying buffer backing
// input[input_index] iff
// * input[input_index] is not a ref,
// * the data type, shape, memory type, and allocator attributes of
// input[input_index] are compatible with those given in dtype, shape,
// memory_type, and attr,
// * refcount on the underlying buffer is one.
// * Either there is no forwarding reservation for either input_index
// or output_index or the specified input is reserved for the specified
// output. More precisely:
//
// These cases mean neither input nor output has a reservation:
// forward_from_array = nullptr
// OR (input_index is not in forward_from_array AND
// (output_index == kNoReservation OR
// forward_from_array[output_index] == kNoReservation))
//
// This case means that input_index is reserved for output_index:
// forward_from_array[output_index] == input_index
//
// This case means the output is reserved to always be allocated,
// never assigned a forwarded input:
// forward_from_array[output_index] == kNeverForward
//
// Otherwise returns nullptr.
// NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
// forwarding is only safe if there are no reads via __ldg() after writes
// to the same address.
std::unique_ptr<Tensor> forward_input(
int input_index, int output_index, DataType output_dtype,
const TensorShape& output_shape, MemoryType output_memory_type,
const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;
// Tries to forward one of the inputs given in input_indices to
// output[output_index]. If none of the given inputs can be forwarded, calls
// allocate_output() to allocate a new output buffer. The index of the
// forwarded input will be assign to output argument forwarded_input (if it's
// not nullptr). If no inputs are forwarded, forwarded_input will be assigned
// -1.
Status forward_input_or_allocate_output(
gtl::ArraySlice<int> candidate_input_indices, int output_index,
const TensorShape& output_shape, Tensor** output,
int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
Status forward_input_or_allocate_output(
gtl::ArraySlice<StringPiece> candidate_input_names,
StringPiece output_name, const TensorShape& output_shape,
Tensor** output) TF_MUST_USE_RESULT;
// Tries to reuse one of the inputs given in input_indices as a temporary.
// If none of the given inputs can be forwarded, calls
// allocate_temp() to allocate a new temporary buffer.
Status forward_input_or_allocate_temp(
gtl::ArraySlice<int> candidate_input_indices, DataType type,
const TensorShape& shape, const AllocatorAttributes& allocator_attr,
Tensor* out_temp) TF_MUST_USE_RESULT;
Status forward_input_or_allocate_temp(
gtl::ArraySlice<int> candidate_input_indices, DataType type,
const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
AllocatorAttributes(), out_temp);
}
// Output
// Returns the named list-valued output in "list", as defined in the OpDef.
// If the named output is not list-valued, returns a one-element list.
Status output_list(StringPiece name, OpOutputList* list);
// If output_required(index) returns true, the OpKernel's Compute() method
// should call allocate_output(index, ...), set_output(index, ...),
// set_output_ref(index, ...), or set the status to a non-ok value.
// If it returns false, it may output, but is not required to do so.
bool output_required(int index) const {
return !params_->outputs_required_array ||
params_->outputs_required_array[index];
}
// If output_expects_forwarding returns true, the OpKernel's Compute() method
// should not allocate the output with allocate_output but instead needs to
// use forward_input.
bool output_expects_forwarding(int index) const {
return params_->forward_from_array != nullptr &&
params_->forward_from_array[index] >= 0;
}
// Allocation of tensors during kernel execution inside the Compute
// method:
//
// There are two methods to allocate Tensors when an Op kernel
// executes.
//
// 1) allocate_output. This should be used to allocate any tensor
// that is going to be used as an output from the Op at the end of
// the current execution. The caller indicates which output the
// Tensor will be assigned to, and the call returns the
// newly-allocated Tensor. The Tensor can subsequently be assigned
// to during kernel execution, and will be used as the designated
// output when the kernel execution completes.
//
// 2) allocate_temp. This should be used to allocate any scratch
// storage that is needed while the kernel is executing, and will
// not be retained by the Op.
//
// In some cases a Tensor needs to be used as an output even though
// it was previously allocated elsewhere. The Tensor may have been
// passed as an input, or stored in a Tensor during a
// previous kernel execution, or allocated earlier in the kernel
// execution at a time when it was not known which output it would
// be assigned to. In this case the kernel can use set_output or
// set_output_ref to indicate that the tensor should be used as the
// designated output. It is legal to use any previously-allocated
// Tensor as an argument to set_output or set_output_ref, including
// Tensors allocated via allocate_temp. There may be a performance
// penalty to using a Tensor that was not allocated using
// allocate_output. This is because allocate_output uses the
// AllocatorAttributes stored in output_attr_array for the
// designated output. In some cases, using the wrong attributes may
// cause an extra copy of the Tensor's buffer.
// Allocates output for the specified output index with shape.
// OpKernelContext retains ownership of the returned pointer. See
// comment above.
//
// If memory allocation fails, returns an error status.
//
// REQUIRES: !IsRefType(expected_output_dtype(index))
Status allocate_output(int index, const TensorShape& shape,
Tensor** tensor) TF_MUST_USE_RESULT;
Status allocate_output(StringPiece name, const TensorShape& shape,
Tensor** tensor) TF_MUST_USE_RESULT;
// The following methods use the supplied attributes instead of
// those in output_attr_array. The caller is responsible for
// ensuring that the attributes are "compatible" with the
// output_attr_array, e.g. the tensor is allocated on the correct
// device. See comment above.
Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
AllocatorAttributes attr) TF_MUST_USE_RESULT;
Status allocate_output(StringPiece name, const TensorShape& shape,
Tensor** tensor,
AllocatorAttributes attr) TF_MUST_USE_RESULT;
// Allocates a temporary Tensor of the specified type and
// shape. Devices such as GPUs that enqueue Ops for lazy execution
// may retain references to the temporary tensors after the Op's
// Compute method has run. See comment above.
Status allocate_temp(DataType type, const TensorShape& shape,
Tensor* out_temp, AllocatorAttributes allocator_attr,
const AllocationAttributes& allocation_attr);
Status allocate_temp(DataType type, const TensorShape& shape,
Tensor* out_temp, AllocatorAttributes allocator_attr) {
return allocate_temp(type, shape, out_temp, allocator_attr,
AllocationAttributes());
}
Status allocate_temp(DataType type, const TensorShape& shape,
Tensor* out_temp) {
return allocate_temp(type, shape, out_temp, AllocatorAttributes());
}
// Copies a tensor (allocated by the caller) to the specified output
// index. REQUIRES: !IsRefType(expected_output_dtype(index))
// REQUIRES: 'tensor' must have the same MemoryType as
// output_memory_types[index]. See comment above.
Status set_output(StringPiece name, const Tensor& tensor);
Status set_output(StringPiece name, Tensor&& tensor);
void set_output(int index, const Tensor& tensor);
void set_output(int index, Tensor&& tensor);
// To output a reference. Caller retains ownership of mu and tensor_for_ref,
// and they must outlive all uses within the step. See comment above.
// REQUIRES: IsRefType(expected_output_dtype(index))
Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
// Returns nullptr if allocate_output() or set_output() have not been called.
Status mutable_output(StringPiece name, Tensor** tensor);
// Return the DeviceContext that should be used for this Op.
//
// If using the templated function, the type must be a subclass
// of DeviceContext.
//
// Returns nullptr if the device did not provide one.
template <typename T>
T* op_device_context();
DeviceContext* op_device_context() {
DeviceContext* ret = params_->op_device_context;
if (ret == nullptr) {
auto* dev_info = device()->tensorflow_accelerator_device_info();
if (dev_info) ret = dev_info->default_context;
}
return ret;
}
AllocatorAttributes input_alloc_attr(int index) const {
if (params_->input_alloc_attrs == nullptr) {
return AllocatorAttributes();
} else {
DCHECK_GE(index, 0);
DCHECK_LT(index, params_->input_alloc_attrs->size());
return (*params_->input_alloc_attrs)[index];
}
}
AllocatorAttributes output_alloc_attr(int index) const {
return params_->output_attr_array[index];
}
gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
gtl::InlinedVector<WrappedAllocator, 4> retrieved;
if (tracking_state_) {
mutex_lock lock(tracking_state_->mu);
retrieved.swap(tracking_state_->wrapped_allocators);
}
return retrieved;
}
// Communication.
//
// An op kernel communicates with outside environment through
// Rendezvous Send() and Recv().
RendezvousInterface* rendezvous() const { return params_->rendezvous; }
CollectiveExecutor* collective_executor() const {
return params_->collective_executor;
}
// An op kernel can access the session state it belongs to.
SessionState* session_state() const { return params_->session_state; }
// Unique identifier of the session it belongs to. Can be empty.
std::string session_handle() const { return params_->session_handle; }
// Metadata about the session. Can be nullptr.
const SessionMetadata* session_metadata() const {
return params_->session_metadata;
}
// An op kernel can access the tensor store of the run it belongs to.
TensorStore* tensor_store() const { return params_->tensor_store; }
// Function call support.
//
// If this kernel invocation is within a function execution,
// call_frame() returns the call frame for the function call.
CallFrameInterface* call_frame() const { return params_->call_frame; }
// If not nullptr, the kernel invoke functions defined in the
// library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
FunctionLibraryRuntime* function_library() const {
return params_->function_library;
}
std::function<void(std::function<void()>)>* runner() const {
return params_->runner;
}
StepStatsCollectorInterface* stats_collector() const {
return params_->stats_collector;
}
// Shared resources accessible to this kernel.
ResourceMgr* resource_manager() const { return params_->resource_manager; }
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
return params_->slice_reader_cache;
}
// Execution.
//
// OpKernels can use these eigen devices to carry out their
// numerical computation.
const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
return *device()->eigen_cpu_device();
}
const Eigen::GpuDevice& eigen_gpu_device() const {
return params_->eigen_gpu_device->device();
}
template <typename EigenDeviceType>
const EigenDeviceType& eigen_device() const;
// Error handling.
// If expected_inputs == inputs() and expected_outputs == output_types(),
// returns OK, else returns INVALID_ARGUMENT with an error message.
// Recommended for Ops with dynamic signatures, where validation can only
// be performed at runtime.
Status MatchSignature(const DataTypeSlice expected_inputs,
const DataTypeSlice expected_outputs);
// An OpKernel should call SetStatus() if Compute() encounters an
// error.
void SetStatus(const Status& status);
const Status& status() const { return status_; }
// Cancellation.
//
// EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
// example of how to use this API.
CancellationManager* cancellation_manager() const {
return params_->cancellation_manager;
}
// Other accessors.
// For control flow.
FrameAndIter frame_iter() const { return params_->frame_iter; }
bool is_input_dead() const { return params_->is_input_dead; }
// May be used, e.g., to get GPU handles, etc.
// TODO(tucker): Add example usage.
DeviceBase* device() const { return params_->device; }
// Per-step container for use by white-listed internal ops.
ScopedStepContainer* step_container() const {
return params_->step_container;
}
// Access to distributed coordination service.
CoordinationServiceAgent* coordination_service_agent() const {
return params_->coordination_service_agent;
}
// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);
void CtxFailureWithWarning(const Status& s);
void CtxFailure(const char* file, int line, const Status& s);
void CtxFailureWithWarning(const char* file, int line, const Status& s);
// Unrecommended functions: these are functions that have some
// current uses but are not recommended for use, and may go away at
// some future major version release.
//
// The following functions all have versions that return Status
// to capture error conditions, and are strongly preferred.
Tensor* mutable_output(int index);
mutex* input_ref_mutex(int index);
void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
TensorValue release_output(int index);
bool track_allocations() const { return params_->track_allocations; }
// Records temp memory allocation. Tensor object is recorded to identify the
// case where temp memory is used as output memory.
void record_temp_memory_allocation(int64_t size, const Tensor& t)
TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
// Returns recorded size of temporary memory;
int64_t temp_memory_allocated() const
TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
// Records persistent memory allocation, size can be negative indicating
// deallocation.
void record_persistent_memory_allocation(int64_t size, int64_t alloc_id = -1)
TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
// Returns recorded size and ids of persistent memory.
int64_t persistent_memory_allocated() const
TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
std::vector<int64_t> persistent_alloc_ids() const
TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
// Resets counters for temp and persistent memory and recorded ids.
void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
bool input_is_ref(int index) const;
void set_record_memory_consumption(bool v);
// Used by OpKernel implementations to track actively running deferred ops.
//
// A deferred op is one whose Compute method returns (or whose ComputeAsync
// method invokes the callback) when work is scheduled onto a device. At that
// point, we don't know when the work will actually complete (or if it has
// already completed) on the device. These functions allow the executor to
// track the status of deferred ops and act accordingly.
//
// Deferred OpKernel implementations must use these methods to get two
// functions. It then must call these two functions in pairs, before and after
// device execution, respectively.
TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
DCHECK(params_->op_kernel->is_deferred());
return params_->inc_num_deferred_ops_function
? params_->inc_num_deferred_ops_function
: []() {};
}
TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
DCHECK(params_->op_kernel->is_deferred());
return params_->dec_num_deferred_ops_function
? params_->dec_num_deferred_ops_function
: []() {};
}
Allocator* get_allocator(AllocatorAttributes attr);
private:
bool record_memory_consumption_ = false;
// Internal common method used when allocating tensor memory
Status allocate_tensor(DataType type, const TensorShape& shape,
Tensor* out_tensor,
AllocatorAttributes allocator_attr) {
return allocate_tensor(type, shape, out_tensor, allocator_attr,
AllocationAttributes());
}
Status allocate_tensor(DataType type, const TensorShape& shape,
Tensor* out_tensor, AllocatorAttributes allocator_attr,
const AllocationAttributes& allocation_attr);
// Helpers for `set_output()`.
// Returns `true` if the tensor was copied into an allocated output.
bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);
void maybe_track_allocations_for_set_output(const Tensor& tensor);
Status get_input_index(StringPiece name, int* out_index) const;
Status get_output_index(StringPiece name, int* out_index) const;
// Initialize the allocated_scope_ids_ set the first time this method is
// called.
void maybe_initialize_scope_id_set();
Status status_;
friend class CollectiveExecutor; // for access to params_
Params* params_; // not owned
gtl::InlinedVector<TensorValue, 4> outputs_;
// Keep track of calls to ScopedAllocator.
// TODO(ayushd): change to absl::flat_hash_set.
std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;
// The following data members are only used when allocation tracking is
// enabled, memory consumption is being recorded, or tensor access is being
// recorded.
struct TrackingState {
mutable mutex mu;
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
TF_GUARDED_BY(mu);
mutable mutex stats_mu;
int64_t temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
int64_t persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
gtl::InlinedVector<std::pair<const void*, int64_t>, 2>
temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu);
gtl::InlinedVector<int64_t, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu);
};
std::unique_ptr<TrackingState> tracking_state_;
// For access to `params_->op_kernel`.
friend void CheckNotInComputeAsync(OpKernelContext* ctx,
const char* correct_macro_name);
TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
};
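Putting the input, allocation, and forwarding pieces together, a hedged Compute sketch for a unary float op that reuses the input buffer in place when the runtime allows it:
// In-place-friendly unary op (sketch). If input 0's buffer can be reused
// for output 0, forward it; otherwise a fresh output tensor is allocated.
void Compute(OpKernelContext* ctx) override {
  const Tensor& input = ctx->input(0);
  Tensor* output = nullptr;
  OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                          {0}, 0, input.shape(), &output));
  output->flat<float>() = input.flat<float>().abs();  // e.g. elementwise abs
}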
Kernel instantiation
At run time, one of the following overloads is called to create the kernel:
// Instantiate an OpKernel that has been registered. Returns nullptr
// if no operation for that type of device / input signature combination
// (and a NOT_FOUND *status), or there is an error in construction (and
// an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership
// of the returned pointer.
// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
DeviceBase* device,
Allocator* allocator,
const NodeDef& node_def,
int graph_def_version, Status* status);
std::unique_ptr<OpKernel> CreateOpKernel(
DeviceType device_type, DeviceBase* device, Allocator* allocator,
const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
Status* status);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
Allocator* allocator, FunctionLibraryRuntime* flib,
const std::shared_ptr<const NodeProperties>& props,
int graph_def_version, OpKernel** kernel);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
Allocator* allocator, FunctionLibraryRuntime* flib,
ResourceMgr* resource_mgr,
const std::shared_ptr<const NodeProperties>& props,
int graph_def_version, OpKernel** kernel);
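A usage sketch, assuming device is a valid DeviceBase* and node_def already has all attrs filled in (e.g. via AddDefaultsToNodeDef()):
// Sketch: instantiate the registered CPU kernel for a fully-specified NodeDef.
Status status;
std::unique_ptr<OpKernel> kernel = CreateOpKernel(
    DeviceType(DEVICE_CPU), device,
    device->GetAllocator(AllocatorAttributes()), node_def,
    TF_GRAPH_DEF_VERSION, &status);
if (!status.ok()) {
  LOG(ERROR) << "CreateOpKernel failed: " << status;
}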
Kernel registration
The official examples from the header:
// Register your OpKernel by specifying the Op's name, the device the
// kernel runs on, any type attr constraints for this kernel, any
// host-memory args, and the class to instantiate. Examples:
//
// // A kernel that supports all types.
// REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
//
// // The following are equivalent ways of specifying that the kernel only
// // works if the "T" type attr is set to DT_FLOAT.
// REGISTER_KERNEL_BUILDER(
// Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
// SubOp<float>);
// // (You would then repeat this for every type supported by "Sub".)
//
// // This form allows you to specify a list of types as the constraint.
// REGISTER_KERNEL_BUILDER(Name("Sub")
// .Device(DEVICE_CPU)
// .TypeConstraint("T", {DT_FLOAT}),
// SubOp<float>);
//
// // A kernel that expects one of the input tensors in host memory.
// REGISTER_KERNEL_BUILDER(
// Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
//
// See kernel_def_builder for details.
Kernel registration likewise relies on macros and a factory.
Tracing the REGISTER_KERNEL_BUILDER flow
The macro chain is REGISTER_KERNEL_BUILDER → REGISTER_KERNEL_BUILDER_IMPL → TF_EXTRACT_KERNEL_NAME → TF_EXTRACT_KERNEL_NAME_IMPL → REGISTER_KERNEL_BUILDER_IMPL_2 → TF_NEW_ID_FOR_INIT → REGISTER_KERNEL_BUILDER_IMPL_3, where the registration finally happens:
// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
// TODO(dodgen): There are some uses of this macro inside functions, where
// kernel_builder refers to (non-const) locals (they should be fixed). To
// accommodate those, kernel_builder.Build() appears as an argument to an
// immediately-called lambda (not in the lambda itself).
#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr, \
is_system_kernel, ...) \
static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr \
TF_ATTRIBUTE_UNUSED = \
TF_INIT_ON_STARTUP_IF(is_system_kernel || \
(SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
SHOULD_REGISTER_OP(op_name))) \
<< ([](::tensorflow::KernelDef const* kernel_def) { \
/* This is the destination: an OpKernelRegistrar is constructed here, */ \
/* wrapping a factory lambda registered via kernel_factory. */ \
::tensorflow::kernel_factory::OpKernelRegistrar registrar( \
kernel_def, #__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { \
return new __VA_ARGS__(context); /* this is where ZeroOutOp is new'ed */ \
}); \
(void)registrar; \
return ::tensorflow::InitOnStartupMarker{}; \
})(kernel_builder_expr.Build());
// Here kernel_builder_expr is the KernelDefBuilder, i.e.
// Name("ZeroOut").Device(DEVICE_CPU); its .Build() result is passed to the
// immediately-invoked lambda, so registration runs right away at startup.
In REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);, Name is really register_kernel::Name, a subclass of KernelDefBuilder, and Device is KernelDefBuilder::Device.
The invocation is therefore effectively REGISTER_KERNEL_BUILDER(<a KernelDefBuilder object>, <the ZeroOutOp class>).
OpKernelRegistrar's constructor ultimately registers into the global KernelRegistry:
void* GlobalKernelRegistry() {
static KernelRegistry* global_kernel_registry = []() {
KernelRegistry* registry = new KernelRegistry;
OpRegistry::Global()->RegisterValidator(ValidateKernelRegistrations);
return registry;
}();
return global_kernel_registry;
}
struct KernelRegistry {
mutex mu;
std::unordered_multimap<string, KernelRegistration> registry // registrations land in this multimap
TF_GUARDED_BY(mu);
};
// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {
class Name : public KernelDefBuilder {
public:
explicit Name(const char* op) : KernelDefBuilder(op) {}
};
} // namespace register_kernel
// Kernel registration appears as:
// REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl)
// We'd like to have "OpName" as a constant-expression, without requiring that
// of the overall KernelDefBuilder expression (beginning with the
// register_kernel::Name constructor above).
//
// So, we pull the "OpName" part to a separate macro-level argument. This
// involves treating Name("OpName") as a macro call, via token-pasting (e.g.
// M_## => M_Name("OpName")), and having it expand to '"OpName",
// Name("OpName")' which is then usable as two arguments.
#define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
name_str, ::tensorflow::register_kernel::Name(name_str)
#define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
#define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...) \
TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
__VA_ARGS__)
// (REGISTER_KERNEL_BUILDER_IMPL_3 is the macro shown, annotated, above.)
// REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split to op_name,
// kernel_builder_expr.
#define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \
is_system_kernel, ...) \
TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name, \
kernel_builder_expr, is_system_kernel, __VA_ARGS__)
// REGISTER_KERNEL_BUILDER, but with is_system_kernel bound.
#define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \
TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder, \
is_system_kernel, __VA_ARGS__)
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
TF_ATTRIBUTE_ANNOTATE("tf:kernel") \
REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__)
// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
// unconditionally even when selective registration is used.
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
TF_ATTRIBUTE_ANNOTATE("tf:kernel") \
TF_ATTRIBUTE_ANNOTATE("tf:kernel:system") \
REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__)
// Checks whether a given kernel is registered on device_type.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);
// If node of node_name, experimental_debug_info, node_op, node_device and
// node_attrs has a corresponding kernel registered on device_type, returns OK
// and fill in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(
const DeviceType& device_type, StringPiece node_name,
bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
const KernelDef** def, std::string* kernel_class_name);
// If node_def has a corresponding kernel registered on device_type,
// returns OK and fill in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
const KernelDef** def, std::string* kernel_class_name);
// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing kernel errors.
void LogAllRegisteredKernels();
// Gets a list of all registered kernels.
KernelList GetAllRegisteredKernels();
// Gets a list of all registered kernels for which predicate returns true
KernelList GetFilteredRegisteredKernels(
const std::function<bool(const KernelDef&)>& predicate);
// Gets a list of all registered kernels for a given op
KernelList GetRegisteredKernelsForOp(StringPiece op_name);
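A hedged lookup sketch: ask, without instantiating anything, which registered kernel would serve node_def on the CPU:
// Sketch: which kernel class (if any) would serve `node_def` on CPU?
const KernelDef* kdef = nullptr;
std::string kernel_class_name;
Status s = FindKernelDef(DeviceType(DEVICE_CPU), node_def,
                         &kdef, &kernel_class_name);
if (s.ok()) {
  LOG(INFO) << "Matched kernel class: " << kernel_class_name;
} else {
  LogAllRegisteredKernels();  // helps debug missing-kernel errors
}
The factory machinery behind registration: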
namespace kernel_factory {
// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
// them. You register factories with the TensorFlow core by constructing an
// OpKernelRegistrar and passing the factory as a constructor parameter.
class OpKernelFactory {
public:
virtual OpKernel* Create(OpKernelConstruction* context) = 0;
virtual ~OpKernelFactory() = default;
};
class OpKernelRegistrar {
public:
// Registers the given kernel factory with TensorFlow. TF will call the
// factory Create() method when it determines that a kernel matching the given
// KernelDef is required.
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
std::unique_ptr<OpKernelFactory> factory) {
InitInternal(kernel_def, kernel_class_name, std::move(factory));
}
// Registers the given factory function with TensorFlow. This is equivalent
// to registering a factory whose Create function invokes `create_fn`.
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
OpKernel* (*create_fn)(OpKernelConstruction*)) {
InitInternal(kernel_def, kernel_class_name,
absl::make_unique<PtrOpKernelFactory>(create_fn));
}
private:
struct PtrOpKernelFactory : public OpKernelFactory {
explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
: create_func_(create_func) {}
OpKernel* Create(OpKernelConstruction* context) override;
OpKernel* (*create_func_)(OpKernelConstruction*);
};
void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
std::unique_ptr<OpKernelFactory> factory);
};
} // namespace kernel_factory
Loading kernels from dynamic libraries
tensorflow/core/framework/op_kernel.cc
All matching .so files under the tensorflow/core/kernels runfiles directory are loaded, using Env::LoadDynamicLibrary:
void LoadDynamicKernelsInternal() {
Env* env = Env::Default();
// Override to allow loading unsafe packages for development.
// DO NOT USE UNLESS YOU KNOW WHAT ABI ISSUES YOU CAN ENCOUNTER.
char* _abi_check_env_var = getenv("TF_REALLY_LOAD_UNSAFE_PACKAGES");
bool override_abi_check = false;
if (_abi_check_env_var != nullptr) {
override_abi_check = strcmp(_abi_check_env_var, "1") == 0;
}
string bazel_kernel_dir =
io::JoinPath(env->GetRunfilesDir(), "tensorflow", "core", "kernels");
std::vector<string> files;
Status s_kernel_dir = env->GetChildren(bazel_kernel_dir, &files);
if (s_kernel_dir.ok()) {
string dll_spec = io::JoinPath(bazel_kernel_dir, kKernelLibPattern);
for (const auto& file : files) {
string fullpath = io::JoinPath(bazel_kernel_dir, file);
if (env->MatchPath(fullpath, dll_spec)) {
Status s = IsProbablySafeToLoad(fullpath);
if (!s.ok() && override_abi_check) {
LOG(WARNING) << "Loading UNSAFE library " << fullpath
<< " because ABI check override is set: "
<< s.error_message();
}
if (s.ok() || override_abi_check) {
// TODO(gunan): Store the handles to the opened files.
void* unused_filehandle;
TF_CHECK_OK(
env->LoadDynamicLibrary(fullpath.c_str(), &unused_filehandle));
} else {
LOG(WARNING) << "Not loading plugin library " << fullpath << ": "
<< s.error_message();
}
}
}
}
}
// Mechanism for loading existing kernel libraries.
void LoadDynamicKernels() {
// TODO(gunan): As more features are available, add intelligent kernel
// selection, and dropping unsuitable kernel logic here.
static absl::once_flag dll_loader_flag;
absl::call_once(dll_loader_flag, LoadDynamicKernelsInternal);
}
Reporting errors when a kernel reads inputs or allocates outputs
tensorflow/core/framework/op_requires.h defines a family of helper macros for this. On failure, they record the error on the context and return from Compute early:
#define OP_REQUIRES_OK(CTX, ...) \
do { \
::tensorflow::Status _s(__VA_ARGS__); \
if (!TF_PREDICT_TRUE(_s.ok())) { \
CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
(CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
return; \
} \
} while (0)
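OP_REQUIRES (from the same header) is the condition-checking sibling. A typical use at the top of Compute, sketched:
// Validate the input shape before touching the data (sketch).
void Compute(OpKernelContext* context) override {
  const Tensor& input_tensor = context->input(0);
  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
              errors::InvalidArgument("expected a vector, got shape ",
                                      input_tensor.shape().DebugString()));
  // On failure, the macro records the status via CtxFailure and returns
  // from Compute early, exactly like OP_REQUIRES_OK above.
}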