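An operator (op) is TensorFlow's mechanism for extending its functionality. An op is split into a declaration and a definition: the declaration is called the op, and the implementation is called a kernel. One declaration can have several implementations, for example different implementations for different devices. Ops must be registered.
Always keep in mind that an op is only a declaration, much like a C++ function declaration; it says nothing about how the op is implemented. For example, you could declare an op named Add that adds two numbers: int Add(int a, int b); This declaration is represented by the proto message OpDef. A graph is a directed acyclic graph formed by connecting the outputs of ops to the inputs of other ops; in effect, it expresses the call relationships between functions.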
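The split between a declaration and its per-device implementations can be sketched in a few lines of plain C++. This is only a toy model, not TensorFlow code: ToyOpDecl, ToyKernel and ToyKernelTable are hypothetical names invented for illustration, and the device strings are arbitrary.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for an op declaration: a name and an arity,
// with no implementation attached.
struct ToyOpDecl {
  std::string name;
  int num_inputs;
};

// Hypothetical kernel type: one concrete implementation of a binary op.
using ToyKernel = std::function<int(int, int)>;

// Kernels are keyed by (op name, device), so a single declaration can
// have several implementations.
class ToyKernelTable {
 public:
  void Add(const ToyOpDecl& op, const std::string& device, ToyKernel kernel) {
    assert(op.num_inputs == 2);  // this toy only models binary ops
    kernels_[std::make_pair(op.name, device)] = std::move(kernel);
  }

  // Returns nullptr when no kernel exists for that op on that device.
  const ToyKernel* Find(const std::string& op,
                        const std::string& device) const {
    auto it = kernels_.find(std::make_pair(op, device));
    return it == kernels_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::pair<std::string, std::string>, ToyKernel> kernels_;
};
```

Registering two lambdas for the same ToyOpDecl under "CPU" and "GPU" and then calling Find with each device name returns two independent implementations of one declaration.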
The op registry interface
It only exposes lookup of an op by name. See tensorflow/core/framework/op.h
class OpRegistryInterface {
public:
virtual ~OpRegistryInterface();
// Returns an error status and sets *op_reg_data to nullptr if no OpDef is
// registered under that name, otherwise returns the registered OpDef.
// Caller must not delete the returned pointer.
virtual Status LookUp(const std::string& op_type_name,
const OpRegistrationData** op_reg_data) const = 0;
// Shorthand for calling LookUp to get the OpDef.
Status LookUpOpDef(const std::string& op_type_name,
const OpDef** op_def) const;
};
Status DefaultValidator(const OpRegistryInterface& op_registry) {
LOG(WARNING) << "No kernel validator registered with OpRegistry.";
return Status::OK();
}
// OpRegistry -----------------------------------------------------------------
OpRegistryInterface::~OpRegistryInterface() {}
Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
const OpDef** op_def) const {
*op_def = nullptr;
const OpRegistrationData* op_reg_data = nullptr;
TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
*op_def = &op_reg_data->op_def;
return Status::OK();
}
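The relationship between LookUp and the LookUpOpDef shorthand can be reproduced without TensorFlow. The sketch below is an assumption-laden miniature: the Toy* types are hypothetical, and a plain bool stands in for Status.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical miniature versions of OpDef and OpRegistrationData.
struct ToyOpDef {
  std::string name;
};
struct ToyRegistrationData {
  ToyOpDef op_def;
};

class ToyRegistryInterface {
 public:
  virtual ~ToyRegistryInterface() = default;

  // Returns false and sets *data to nullptr when the name is unknown.
  virtual bool LookUp(const std::string& name,
                      const ToyRegistrationData** data) const = 0;

  // Shorthand: implemented once in the base class on top of LookUp.
  bool LookUpOpDef(const std::string& name, const ToyOpDef** op_def) const {
    assert(op_def != nullptr);
    *op_def = nullptr;
    const ToyRegistrationData* data = nullptr;
    if (!LookUp(name, &data)) return false;
    *op_def = &data->op_def;
    return true;
  }
};

// A concrete registry backed by std::map (element addresses stay stable).
class ToyMapRegistry : public ToyRegistryInterface {
 public:
  void Add(const std::string& name) {
    ops_[name] = ToyRegistrationData{ToyOpDef{name}};
  }

  bool LookUp(const std::string& name,
              const ToyRegistrationData** data) const override {
    auto it = ops_.find(name);
    if (it == ops_.end()) {
      *data = nullptr;
      return false;
    }
    *data = &it->second;
    return true;
  }

 private:
  std::map<std::string, ToyRegistrationData> ops_;
};
```

Only LookUp is virtual, so every concrete registry gets LookUpOpDef for free; the caller never owns the returned pointer.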
// A concrete implementation
class OpRegistry : public OpRegistryInterface {
public:
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
OpRegistry();
~OpRegistry() override;
void Register(const OpRegistrationDataFactory& op_data_factory);
Status LookUp(const std::string& op_type_name,
const OpRegistrationData** op_reg_data) const override;
// Returns OpRegistrationData* of registered op type, else returns nullptr.
const OpRegistrationData* LookUp(const std::string& op_type_name) const;
// Fills *ops with all registered OpDefs (except those with names
// starting with '_' if include_internal == false) sorted in
// ascending alphabetical order.
void Export(bool include_internal, OpList* ops) const;
// Returns ASCII-format OpList for all registered OpDefs (except
// those with names starting with '_' if include_internal == false).
std::string DebugString(bool include_internal) const;
// A singleton available at startup.
static OpRegistry* Global();
// Get all registered ops.
void GetRegisteredOps(std::vector<OpDef>* op_defs);
// Get all `OpRegistrationData`s.
void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
// Registers a function that validates op registry.
void RegisterValidator(
std::function<Status(const OpRegistryInterface&)> validator) {
op_registry_validator_ = std::move(validator);
}
// Watcher, a function object.
// The watcher, if set by SetWatcher(), is called every time an op is
// registered via the Register function. The watcher is passed the Status
// obtained from building and adding the OpDef to the registry, and the OpDef
// itself if it was successfully built. A watcher returns a Status which is in
// turn returned as the final registration status.
typedef std::function<Status(const Status&, const OpDef&)> Watcher;
// An OpRegistry object has only one watcher. This interface is not thread
// safe, as different clients are free to set the watcher any time.
// Clients are expected to atomically perform the following sequence of
// operations :
// SetWatcher(a_watcher);
// Register some ops;
// op_registry->ProcessRegistrations();
// SetWatcher(nullptr);
// Returns a non-OK status if a non-null watcher is over-written by another
// non-null watcher.
Status SetWatcher(const Watcher& watcher);
// Process the current list of deferred registrations. Note that calls to
// Export, LookUp and DebugString would also implicitly process the deferred
// registrations. Returns the status of the first failed op registration or
// Status::OK() otherwise.
Status ProcessRegistrations() const;
// Defer the registrations until a later call to a function that processes
// deferred registrations are made. Normally, registrations that happen after
// calls to Export, LookUp, ProcessRegistrations and DebugString are processed
// immediately. Call this to defer future registrations.
void DeferRegistrations();
// Clear the registrations that have been deferred.
void ClearDeferredRegistrations();
private:
// Ensures that all the functions in deferred_ get called, their OpDef's
// registered, and returns with deferred_ empty. Returns true the first
// time it is called. Prints a fatal log if any op registration fails.
bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Calls the functions in deferred_ and registers their OpDef's
// It returns the Status of the first failed op registration or Status::OK()
// otherwise.
Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Add 'def' to the registry with additional data 'data'. On failure, or if
// there is already an OpDef with that name registered, returns a non-okay
// status.
Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const;
mutable mutex mu_;
// Functions in deferred_ may only be called with mu_ held.
mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_);
// Values are owned.
mutable std::unordered_map<string, const OpRegistrationData*> registry_
TF_GUARDED_BY(mu_);  // this is where ops end up registered
mutable bool initialized_ TF_GUARDED_BY(mu_);
// Registry watcher.
mutable Watcher watcher_ TF_GUARDED_BY(mu_);
std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
};
These interfaces look simple, but the types they take, OpDef and OpRegistrationData, are fairly complex.
OpDef
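An op can have multiple input arguments, multiple attributes, multiple output arguments, and multiple control outputs. The input and output arguments are Tensors.
Attribute values are fixed when the graph is built and do not change afterwards, whereas input arguments carry the data that varies each time the graph is executed.
OpDef is defined in a proto: tensorflow/core/framework/op_def.proto
This proto declares an op. In effect it simply records the op's metadata, such as its name and its input and output arguments.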
message OpDef {
string name = 1; // op name
message ArgDef { // an input or output argument of the op
string name = 1;
string description = 2;
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
string type_list_attr = 6;
repeated ResourceHandleProto.DtypeAndShape handle_data = 7;
bool is_ref = 16;
FullTypeDef experimental_full_type = 17;
}
repeated ArgDef input_arg = 2;
repeated ArgDef output_arg = 3;
repeated string control_output = 20; // control outputs
message AttrDef { // an op attribute; its value is fixed at graph construction time
string name = 1;
string type = 2;
AttrValue default_value = 3;
string description = 4;
bool has_minimum = 5;
int64 minimum = 6;
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4; // attributes
OpDeprecation deprecation = 8;
string summary = 5;
string description = 6;
bool is_commutative = 18;
bool is_aggregate = 16; // for things like add
bool is_stateful = 17; // for things like variables, queue
bool allows_uninitialized_input = 19; // for Assign, etc.
bool is_distributed_communication = 21;
}
message OpDeprecation {
int32 version = 1;
string explanation = 2;
}
message OpList { // a list of ops
repeated OpDef op = 1;
}
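Building an op with OpDefBuilder
The builder accepts strings in a specific syntax to add input arguments, output arguments, and so on. Once everything has been added, calling Finalize(OpRegistrationData* op_reg_data) produces an OpRegistrationData, which contains the OpDef.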
tensorflow/core/framework/op_def_builder.h
// Builder class passed to the REGISTER_OP() macro.
class OpDefBuilder {
public:
explicit OpDefBuilder(std::string op_name);
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
// format "<name>:<type>" or "<name>:<type>=<default>"
// where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*
// (by convention only using capital letters for attrs that can be inferred)
// <type> can be:
// "string", "int", "float", "bool", "type", "shape", or "tensor"
// "numbertype", "realnumbertype", "quantizedtype"
// (meaning "type" with a restriction on valid values)
// "{int32,int64}" or {realnumbertype,quantizedtype,string}"
// (meaning "type" with a restriction containing unions of value types)
// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
// (meaning "string" with a restriction on valid values)
// "list(string)", ..., "list(tensor)", "list(numbertype)", ...
// (meaning lists of the above types)
// "int >= 2" (meaning "int" with a restriction on valid values)
// "list(string) >= 2", "list(int) >= 2"
// (meaning "list(string)" / "list(int)" with length at least 2)
// <default>, if included, should use the Proto text format
// of <type>. For lists use [a, b, c] format.
//
// Note that any attr specifying the length of an input or output will
// get a default minimum of 1 unless the >= # syntax is used.
//
// TODO(josh11b): Perhaps support restrictions and defaults as optional
// extra arguments to Attr() instead of encoding them in the spec string.
// TODO(josh11b): Would like to have better dtype handling for tensor attrs:
// * Ability to say the type of an input/output matches the type of
// the tensor.
// * Ability to restrict the type of the tensor like the existing
// restrictions for type attrs.
// Perhaps by linking the type of the tensor to a type attr?
OpDefBuilder& Attr(std::string spec);
// Adds an input or output to this OpDefBuilder (and returns *this).
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
// where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
// * For a single tensor: <type>
// * For a sequence of tensors with the same type: <number>*<type>
// * For a sequence of tensors with different types: <type-list>
// Where:
// <type> is either one of "float", "int32", "string", ...
// or the name of an attr (see above) with type "type".
// <number> is the name of an attr with type "int".
// <type-list> is the name of an attr with type "list(type)".
// TODO(josh11b): Indicate Ref() via an optional argument instead of
// in the spec?
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
// handling?
OpDefBuilder& Input(std::string spec);
OpDefBuilder& Output(std::string spec);
// Turns on the indicated boolean flag in this OpDefBuilder (and
// returns *this).
OpDefBuilder& SetIsCommutative();
OpDefBuilder& SetIsAggregate();
OpDefBuilder& SetIsStateful();
OpDefBuilder& SetAllowsUninitializedInput();
OpDefBuilder& SetIsDistributedCommunication();
// Deprecate the op at a certain GraphDef version.
OpDefBuilder& Deprecated(int version, std::string explanation);
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
// <1-line summary>
// <rest of the description>
// <name>: <description of name>
// <name>: <description of name>
// <if long, indent the description on subsequent lines>
// Where <name> is the name of an attr, input, or output. Please
// wrap docs at 72 columns so that it may be indented in the
// generated output. For tensor inputs or outputs (not attrs), you
// may start the description with an "=" (like name:= <description>)
// to suppress the automatically-generated type documentation in
// generated output.
OpDefBuilder& Doc(std::string text);
// Sets the function to be used as type constructor.
// See OpRegistrationData::type_ctor.
OpDefBuilder& SetTypeConstructor(OpTypeConstructor c);
// Sets the function to be used for forward type inference.
// See OpRegistrationData::fwd_type_fn.
OpDefBuilder& SetForwardTypeFn(ForwardTypeInferenceFn f);
// Sets the shape function to be used for shape inference.
//
// Note that currently (October 2016), python code still requires a
// RegisterShape call to invoke this; see call_cpp_shape_fn in
// python/framework/common_shapes.py
OpDefBuilder& SetShapeFn(OpShapeInferenceFn fn);
// Allows the `<type>` in calls to `Attr()` to be "any".
// This is used by PythonAPIWrapper for pass-through parameters.
OpDefBuilder& AllowAttrTypeAny();
// Sets op_reg_data->op_def to the requested OpDef and
// op_reg_data->shape_inference_fn to the requested shape inference function,
// or returns an error.
// Must be called after all of the above methods.
//
// Note that OpDefBuilder only reports parsing errors. You should also
// call ValidateOpDef() to detect other problems.
Status Finalize(OpRegistrationData* op_reg_data) const;
private:
friend class FunctionDefHelper;
// Adds control output to this OpDefBuilder (and returns *this).
// The <name> must be a valid node name (matches regexp
// [a-zA-Z][a-zA-Z0-9_]*). Named control output can only exist for functions.
OpDefBuilder& ControlOutput(std::string name);
OpDef* op_def() { return &op_reg_data_.op_def; }
OpRegistrationData op_reg_data_;
std::vector<string> attrs_;
std::vector<string> inputs_;
std::vector<string> outputs_;
std::vector<string> control_outputs_;
std::string doc_;
std::vector<string> errors_;
bool allow_attr_type_any_ = false;
};
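The builder idea (collect spec strings first, parse them only in Finalize) can be sketched as follows. This is a heavily simplified assumption, not the real parser: ToyOpDefBuilder is a hypothetical class that only understands "<name>: <type>" and ignores attrs, Ref(), lists and every other part of the real spec grammar.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical miniature OpDef: just a name and typed arguments.
struct ToyArg {
  std::string name;
  std::string type;
};
struct ToyOpDef {
  std::string name;
  std::vector<ToyArg> inputs;
  std::vector<ToyArg> outputs;
};

class ToyOpDefBuilder {
 public:
  explicit ToyOpDefBuilder(std::string name) : name_(std::move(name)) {}

  // Specs are only stored here; parsing is postponed until Finalize.
  ToyOpDefBuilder& Input(std::string spec) {
    inputs_.push_back(std::move(spec));
    return *this;
  }
  ToyOpDefBuilder& Output(std::string spec) {
    outputs_.push_back(std::move(spec));
    return *this;
  }

  // Parses all stored specs. Returns false and leaves *out untouched if
  // any spec is malformed.
  bool Finalize(ToyOpDef* out) const {
    assert(out != nullptr);
    ToyOpDef def;
    def.name = name_;
    if (!ParseAll(inputs_, &def.inputs)) return false;
    if (!ParseAll(outputs_, &def.outputs)) return false;
    *out = std::move(def);
    return true;
  }

 private:
  static std::string Trim(const std::string& s) {
    const auto begin = s.find_first_not_of(" \t");
    if (begin == std::string::npos) return "";
    const auto end = s.find_last_not_of(" \t");
    return s.substr(begin, end - begin + 1);
  }

  // Splits "<name>: <type>" at the first colon.
  static bool ParseAll(const std::vector<std::string>& specs,
                       std::vector<ToyArg>* args) {
    for (const std::string& spec : specs) {
      const auto colon = spec.find(':');
      if (colon == std::string::npos) return false;
      ToyArg arg{Trim(spec.substr(0, colon)), Trim(spec.substr(colon + 1))};
      if (arg.name.empty() || arg.type.empty()) return false;
      args->push_back(std::move(arg));
    }
    return true;
  }

  std::string name_;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};
```

Because every setter returns *this, a call such as ToyOpDefBuilder("ZeroOut").Input("to_zero: int32").Output("zeroed: int32").Finalize(&def) reads like the chained registration syntax used with REGISTER_OP.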
OpRegistrationData
struct OpRegistrationData {
public:
OpRegistrationData() {}
OpRegistrationData(const OpDef& def) : op_def(def) {}
OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn,
bool is_function = false)
: op_def(def), shape_inference_fn(fn), is_function_op(is_function) {}
OpDef op_def;
OpShapeInferenceFn shape_inference_fn;
OpTypeConstructor type_ctor;
ForwardTypeInferenceFn fwd_type_fn;
bool is_function_op = false;
};
How op registration works
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
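The REGISTER_OP macro actually creates an OpDefBuilderWrapper object, whose class is shown below.
The chained .Input, .Output, and similar calls are calls to the methods of this object; each of them in turn forwards to the corresponding OpDefBuilder method.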
namespace register_op {
class OpDefBuilderWrapper {
public:
explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
OpDefBuilderWrapper& Attr(std::string spec) {
builder_.Attr(std::move(spec));
return *this;
}
OpDefBuilderWrapper& Input(std::string spec) {
builder_.Input(std::move(spec));
return *this;
}
OpDefBuilderWrapper& Output(std::string spec) {
builder_.Output(std::move(spec));
return *this;
}
OpDefBuilderWrapper& SetIsCommutative() {
builder_.SetIsCommutative();
return *this;
}
OpDefBuilderWrapper& SetIsAggregate() {
builder_.SetIsAggregate();
return *this;
}
OpDefBuilderWrapper& SetIsStateful() {
builder_.SetIsStateful();
return *this;
}
OpDefBuilderWrapper& SetDoNotOptimize() {
// We don't have a separate flag to disable optimizations such as constant
// folding and CSE so we reuse the stateful flag.
builder_.SetIsStateful();
return *this;
}
OpDefBuilderWrapper& SetAllowsUninitializedInput() {
builder_.SetAllowsUninitializedInput();
return *this;
}
OpDefBuilderWrapper& Deprecated(int version, std::string explanation) {
builder_.Deprecated(version, std::move(explanation));
return *this;
}
OpDefBuilderWrapper& Doc(std::string text) {
builder_.Doc(std::move(text));
return *this;
}
OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) {
builder_.SetShapeFn(std::move(fn));
return *this;
}
OpDefBuilderWrapper& SetIsDistributedCommunication() {
builder_.SetIsDistributedCommunication();
return *this;
}
OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) {
builder_.SetTypeConstructor(std::move(fn));
return *this;
}
OpDefBuilderWrapper& SetForwardTypeFn(ForwardTypeInferenceFn fn) {
builder_.SetForwardTypeFn(std::move(fn));
return *this;
}
const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
// Called by InitOnStartupMarker, which is described below
InitOnStartupMarker operator()();
private:
mutable ::tensorflow::OpDefBuilder builder_;
};
} // namespace register_op
#define REGISTER_OP_IMPL(ctr, name, is_system_op) \
static ::tensorflow::InitOnStartupMarker const register_op##ctr \
TF_ATTRIBUTE_UNUSED = \
TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \
<< ::tensorflow::register_op::OpDefBuilderWrapper(name)
#define REGISTER_OP(name) \
TF_ATTRIBUTE_ANNOTATE("tf:op") \
TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)
// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
// that the op is registered unconditionally even when selective
// registration is used.
#define REGISTER_SYSTEM_OP(name) \
TF_ATTRIBUTE_ANNOTATE("tf:op") \
TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true)
} // namespace tensorflow
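- The REGISTER_OP macro invokes TF_NEW_ID_FOR_INIT, which uses the __COUNTER__ macro to generate a unique ID.
- That ID is passed to REGISTER_OP_IMPL as the ctr parameter.
- REGISTER_OP_IMPL defines a static variable of type ::tensorflow::InitOnStartupMarker. Its name is register_op##ctr, that is, register_op0, register_op1, and so on.
- TF_INIT_ON_STARTUP_IF(cond) << f is equivalent to !cond ? InitOnStartupMarker{} : (InitOnStartupMarker{} << f), where f is ::tensorflow::register_op::OpDefBuilderWrapper(name) together with any chained calls. If cond is false nothing happens; otherwise the << expression is evaluated, which works because InitOnStartupMarker overloads operator<<.
- In the code below, InitOnStartupMarker::operator<< calls the operator() method of OpDefBuilderWrapper.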
struct InitOnStartupMarker {
constexpr InitOnStartupMarker operator<<(InitOnStartupMarker) const {
return *this;
}
template <typename T>
constexpr InitOnStartupMarker operator<<(T&& v) const {
return std::forward<T>(v)();  // equivalent to calling operator() on the OpDefBuilderWrapper object
}
};
#define TF_INIT_ON_STARTUP_IF(cond) \
(::std::integral_constant<bool, !(cond)>::value) \
? ::tensorflow::InitOnStartupMarker{} \
: ::tensorflow::InitOnStartupMarker {}
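The same static-initialization trick can be tried in isolation. The sketch below is a hypothetical miniature: ToyMarker, ToyWrapper, ToyRegistry and the TOY_* macros are invented names, the registry is just a vector of strings, and __COUNTER__ is a compiler extension (available in GCC, Clang and MSVC) rather than standard C++.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical global list standing in for OpRegistry::Global(). A
// function-local static is safe to use during static initialization.
inline std::vector<std::string>& ToyRegistry() {
  static std::vector<std::string> names;
  return names;
}

// Empty marker type: `marker << f` simply invokes f().
struct ToyMarker {
  ToyMarker operator<<(ToyMarker) const { return *this; }
  template <typename T>
  ToyMarker operator<<(T&& v) const {
    return std::forward<T>(v)();
  }
};

// Plays the role of OpDefBuilderWrapper: operator() does the registering.
class ToyWrapper {
 public:
  explicit ToyWrapper(const char* name) : name_(name) {
    assert(name != nullptr);
  }
  ToyWrapper& Input(const char* /*spec*/) { return *this; }  // spec ignored
  ToyMarker operator()() {
    ToyRegistry().push_back(name_);
    return {};
  }

 private:
  std::string name_;
};

// When cond is false the first branch is taken and the `<< wrapper`
// expression that follows the macro is never evaluated.
#define TOY_INIT_IF(cond) \
  (::std::integral_constant<bool, !(cond)>::value) ? ToyMarker{} : ToyMarker {}

#define TOY_CONCAT_INNER(a, b) a##b
#define TOY_CONCAT(a, b) TOY_CONCAT_INNER(a, b)
#define TOY_REGISTER_IF(cond, name)                              \
  static const ToyMarker TOY_CONCAT(toy_register_, __COUNTER__) = \
      TOY_INIT_IF(cond) << ToyWrapper(name)

// Member access binds tighter than <<, so .Input(...) runs on the wrapper
// before operator<< invokes it.
TOY_REGISTER_IF(true, "ZeroOut").Input("to_zero: int32");
TOY_REGISTER_IF(false, "Skipped");
```

By the time main starts, ToyRegistry() should contain only "ZeroOut": the second statement takes the first branch of the conditional, so its wrapper is never constructed.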
The actual registration happens here:
InitOnStartupMarker OpDefBuilderWrapper::operator()() {
OpRegistry::Global()->Register(
[builder =
std::move(builder_)](OpRegistrationData* op_reg_data) -> Status {
return builder.Finalize(op_reg_data);
});
return {};
}
// static
OpRegistry* OpRegistry::Global() {
static OpRegistry* global_op_registry = new OpRegistry;
return global_op_registry;
}
void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
mutex_lock lock(mu_);
if (initialized_) {
TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
} else {
deferred_.push_back(op_data_factory);
}
}
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
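In the end the built op, that is, its OpRegistrationData, is inserted into map<op_name, OpRegistrationData*> OpRegistry::registry_. Note that Register itself may only queue the factory in deferred_: if the registry has not been initialized yet, the factory is run later, when the deferred registrations are processed (for example on the first LookUp).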
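The deferred behaviour of Register can be modelled in a few lines. This is a sketch under simplifying assumptions, not the real class: ToyOpRegistry is hypothetical, std::mutex replaces TensorFlow's mutex, a bool replaces Status, and only LookUp triggers processing of the queue.

```cpp
#include <cassert>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyOpData {
  std::string name;
};

// A factory fills in the op data on demand; false means failure.
using ToyFactory = std::function<bool(ToyOpData*)>;

class ToyOpRegistry {
 public:
  void Register(const ToyFactory& factory) {
    std::lock_guard<std::mutex> lock(mu_);
    if (initialized_) {
      const bool ok = RegisterLocked(factory);  // run immediately
      assert(ok);
      (void)ok;
    } else {
      deferred_.push_back(factory);  // just queue it for later
    }
  }

  // The first call runs every queued factory; returns nullptr if unknown.
  const ToyOpData* LookUp(const std::string& name) const {
    std::lock_guard<std::mutex> lock(mu_);
    if (!initialized_) {
      for (const ToyFactory& factory : deferred_) RegisterLocked(factory);
      deferred_.clear();
      initialized_ = true;
    }
    auto it = registry_.find(name);
    return it == registry_.end() ? nullptr : &it->second;
  }

 private:
  // Must be called with mu_ held. Fails on factory error or duplicate name.
  bool RegisterLocked(const ToyFactory& factory) const {
    ToyOpData data;
    if (!factory(&data)) return false;
    return registry_.emplace(data.name, data).second;
  }

  mutable std::mutex mu_;
  mutable std::vector<ToyFactory> deferred_;
  mutable std::unordered_map<std::string, ToyOpData> registry_;
  mutable bool initialized_ = false;
};
```

Registering a factory that increments a counter shows the effect: the counter stays at zero after Register and only changes on the first LookUp, while factories registered after that first LookUp run immediately.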
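Summary
Declaring an op means building an OpRegistrationData, which requires adding inputs, outputs, attributes, and other parameters. OpDefBuilder makes this convenient: inputs and outputs are added step by step, and a final call to Finalize produces the OpRegistrationData. An op declaration must be registered with OpRegistry. The REGISTER_OP macro defines a static InitOnStartupMarker variable whose initializer constructs an OpDefBuilderWrapper; when that static variable is initialized, the wrapper hands OpRegistry a factory that builds the OpRegistrationData, and OpRegistry stores the result in registry_ once the registration is processed.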