TFLite Data Structures
lite/c/c_api_internal.h
// This file defines a C API for implementing operations in tflite.
// These operations can be defined using c++ but the interface between
// the interpreter and the operations is C. (The ops may be implemented in C++, but the interface between the interpreter and the ops is plain C.)
//
// Summary of abstractions
// TF_LITE_ENSURE - Self-sufficient error checking
// TfLiteStatus - Status reporting
// TfLiteIntArray - stores tensor shapes (dims),
// TfLiteContext - allows an op to access the tensors
// TfLiteTensor - tensor (a multidimensional array)
// TfLiteNode - a single node or operation
// TfLiteRegistration - the implementation of a conceptual operation.
TfLiteTensor
// A tensor in the interpreter system which is a wrapper around a buffer of
// data including a dimensionality (or NULL if not currently defined).
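For orientation, an abbreviated sketch of the fields this wrapper carries (based on the c_api_internal.h of this era; the exact members are an assumption and vary across TFLite versions):
typedef struct {
  TfLiteType type;         // element type, e.g. kTfLiteFloat32
  TfLitePtrUnion data;     // pointer to the underlying data buffer
  TfLiteIntArray* dims;    // shape; NULL if not currently defined
  TfLiteQuantizationParams params;       // quantization scale / zero_point
  TfLiteAllocationType allocation_type;  // arena, mmap, dynamic, ...
  size_t bytes;            // size of the buffer in bytes
  const char* name;        // tensor name from the model
  // ... delegate / buffer-handle fields omitted ...
} TfLiteTensor;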
TfLiteNode
// A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs and user defined data, not
// other features like the type.
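An abbreviated sketch of its fields (again hedged; exact members vary by version):
typedef struct {
  TfLiteIntArray* inputs;           // indices into context->tensors
  TfLiteIntArray* outputs;          // indices into context->tensors
  TfLiteIntArray* temporaries;      // scratch tensors used by the kernel
  void* user_data;                  // whatever init() returned (see below)
  void* builtin_data;               // parsed builtin params, e.g. TfLiteLSTMParams*
  const void* custom_initial_data;  // raw custom_options for custom ops
  int custom_initial_data_size;
} TfLiteNode;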
TfLiteContext
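Allows an op to access the tensors and to call back into the interpreter. A sketch of its shape, inferred from the fields the Interpreter constructor fills in later in this article (function-pointer signatures are abbreviated and should be treated as approximate):
typedef struct TfLiteContext {
  size_t tensors_size;    // number of tensors in the graph
  TfLiteTensor* tensors;  // array of all tensors
  void* impl_;            // opaque pointer back to the C++ Interpreter
  // C function pointers through which an op calls back into the interpreter:
  TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
                               TfLiteIntArray* new_size);
  void (*ReportError)(struct TfLiteContext*, const char* format, ...);
  TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
                             int* first_new_tensor_index);
  // ... delegate entries (GetNodeAndRegistration etc.) omitted ...
} TfLiteContext;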
TfLiteRegistration
It contains the function pointers init, free, prepare, invoke and profiling_string, plus the op code used for builtins (builtin_code), the user-defined name (custom_name), and the version number:
typedef struct _TfLiteRegistration {
// Initializes the op from serialized data.
// If a built-in op: (do buffer and length come from the flatbuffers here?)
// `buffer` is the op's params data (TfLiteLSTMParams*).
// `length` is zero.
// If custom op:
// `buffer` is the op's `custom_options`.
// `length` is the size of the buffer.
//
// Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
// or an instance of a struct).
//
// The returned pointer will be stored with the node in the `user_data` field,
// accessible within prepare and invoke functions below.
// NOTE: if the data is already in the desired format, simply implement this
// function to return `nullptr` and implement the free function to be a no-op.
void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
// The pointer `buffer` is the data previously returned by an init invocation.
void (*free)(TfLiteContext* context, void* buffer);
// prepare is called when the inputs this node depends on have been resized.
// context->ResizeTensor() can be called to request output tensors to be
// resized.
// (i.e. this function is invoked when the node's inputs are resized)
// Returns kTfLiteOk on success.
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
// Execute the node (should read node->inputs and output to node->outputs).
// Returns kTfLiteOk on success.
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
// profiling_string is called during summarization of profiling information
// in order to group executions together. Providing a value here will cause a
// given op to appear multiple times in the profiling report. This is
// particularly useful for custom ops that can perform significantly
// different calculations depending on their `user-data`.
const char* (*profiling_string)(const TfLiteContext* context,
const TfLiteNode* node);
// Builtin codes. If this kernel refers to a builtin this is the code
// of the builtin. This is so we can do marshaling to other frameworks like
// NN API.
// Note: It is the responsibility of the registration binder to set this
// properly.
int32_t builtin_code;
// Custom op name. If the op is a builtin, this will be null.
// Note: It is the responsibility of the registration binder to set this
// properly.
// WARNING: This is an experimental interface that is subject to change.
const char* custom_name;
// The version of the op.
// Note: It is the responsibility of the registration binder to set this
// properly.
int version;
} TfLiteRegistration;
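To make the four function pointers concrete, here is a minimal hypothetical custom-op skeleton (MY_OP and the MyOp* names are invented for illustration) wired into a TfLiteRegistration, following the same pattern the builtin ops use:
namespace {
void* MyOpInit(TfLiteContext* context, const char* buffer, size_t length) {
  return nullptr;  // no per-node state needed for this toy op
}
void MyOpFree(TfLiteContext* context, void* buffer) {}  // nothing to free
TfLiteStatus MyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
  // Typically: validate input count/types, then resize outputs via
  // context->ResizeTensor(...).
  return kTfLiteOk;
}
TfLiteStatus MyOpEval(TfLiteContext* context, TfLiteNode* node) {
  // Read node->inputs, write node->outputs.
  return kTfLiteOk;
}
}  // namespace

TfLiteRegistration* Register_MY_OP() {
  static TfLiteRegistration r = {MyOpInit, MyOpFree, MyOpPrepare, MyOpEval};
  return &r;
}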
Operator registration: register.cc
Unlike Java, C++ imposes no strong relationship between class names and file names, so the class is not found in a file named after it.
The register.h header defines BuiltinOpResolver:
namespace tflite {
namespace ops {
namespace builtin {
class BuiltinOpResolver : public MutableOpResolver {
public:
BuiltinOpResolver();
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const override;
const TfLiteRegistration* FindOp(const char* op, int version) const override;
};
} // namespace builtin
} // namespace ops
} // namespace tflite
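A short usage sketch of this resolver:
tflite::ops::builtin::BuiltinOpResolver resolver;
const TfLiteRegistration* reg =
    resolver.FindOp(tflite::BuiltinOperator_RELU, /*version=*/1);
// reg->invoke is the function the interpreter will call for this op.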
Code structure of register.cc: registering the OpResolver
namespace tflite {
namespace ops {
namespace builtin {
TfLiteRegistration* Register_RELU();
1] Implementation of a Register_* function (Register_SOFTMAX shown here): it returns the address of a function-local static TfLiteRegistration, which effectively acts as a global:
TfLiteRegistration* Register_SOFTMAX() {
static TfLiteRegistration r = {activations::Init, activations::Free,
activations::SoftmaxPrepare,
activations::SoftmaxEval};
return &r;
}
3] FindOp simply looks up the MutableOpResolver member maps builtins_ and custom_ops_:
const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op,
int version) const {
return MutableOpResolver::FindOp(op, version);
}
BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_RELU, Register_RELU());
  // ... one AddBuiltin()/AddCustom() call per supported op ...
}
2] The implementation of AddBuiltin: it simply writes entries into the builtins_ map (custom ops go into custom_ops_):
void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
                                   const TfLiteRegistration* registration,
                                   int min_version, int max_version) {
  for (int version = min_version; version <= max_version; ++version) {
    TfLiteRegistration new_registration = *registration;
    new_registration.custom_name = nullptr;
    new_registration.builtin_code = op;
    new_registration.version = version;
    auto op_key = std::make_pair(op, version);
    builtins_[op_key] = new_registration;
  }
}
}  // namespace builtin
}  // namespace ops
}  // namespace tflite
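The custom-op side is symmetric. A hedged sketch of MutableOpResolver::AddCustom (keyed by op name instead of builtin code; details may differ slightly between versions):
void MutableOpResolver::AddCustom(const char* name,
                                  const TfLiteRegistration* registration,
                                  int min_version, int max_version) {
  for (int version = min_version; version <= max_version; ++version) {
    TfLiteRegistration new_registration = *registration;
    new_registration.builtin_code = BuiltinOperator_CUSTOM;
    new_registration.custom_name = name;
    new_registration.version = version;
    auto op_key = std::make_pair(std::string(name), version);
    custom_ops_[op_key] = new_registration;
  }
}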
native: createInterpreter
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle) {
tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle);
auto resolver = ::tflite::CreateOpResolver();
std::unique_ptr<tflite::Interpreter> interpreter;
TfLiteStatus status =  // create the Interpreter from the FlatBufferModel and the OpResolver, via InterpreterBuilder
tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter);
return reinterpret_cast<jlong>(interpreter.release());
}
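For comparison, the same flow without the JNI plumbing, in plain C++ (the canonical usage from the TFLite documentation):
auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
tflite::ops::builtin::BuiltinOpResolver resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
interpreter->AllocateTensors();  // then fill inputs and call interpreter->Invoke()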
1] CreateOpResolver: builtin_ops_jni.cc
// The JNI code in nativeinterpreterwrapper_jni.cc expects a CreateOpResolver() function in
// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the
// builtin ops. For smaller binary sizes users should avoid linking this in, and
// should provide a custom CreateOpResolver() instead.
// CreateOpResolver registers every operator; to trim the binary, modify this implementation
std::unique_ptr<OpResolver> CreateOpResolver() {  // NOLINT
  return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
      new tflite::ops::builtin::BuiltinOpResolver());  // see the data-structure analysis above
}
2] tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter);
InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
const OpResolver& op_resolver)
: model_(model.GetModel()),
op_resolver_(op_resolver),
error_reporter_(model.error_reporter()),
allocation_(model.allocation())
{}  // this is the constructor of InterpreterBuilder; don't confuse it with Interpreter's own constructor
The trailing (&interpreter) looks baffling at first. What does it mean? The answer is that operator() is overloaded here:
TfLiteStatus InterpreterBuilder::operator()(
std::unique_ptr<Interpreter>* interpreter) {
return operator()(interpreter, /*num_threads=*/-1);
}
TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
  // ... the real work happens here; see 3] below ...
}
How the model file is parsed and how the context etc. are produced is all implemented in this function.
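A minimal standalone illustration of the operator() trick (toy code, not TFLite):
#include <memory>

struct Builder {
  // Overloading operator() lets an object be invoked like a function.
  int operator()(std::unique_ptr<int>* out) {
    out->reset(new int(42));
    return 0;
  }
};

int main() {
  std::unique_ptr<int> result;
  Builder()(&result);  // construct a temporary Builder, then call its operator()
  return *result == 42 ? 0 : 1;
}
This is exactly the shape of tflite::InterpreterBuilder(*model, resolver)(&interpreter): construct a builder object, then immediately call it.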
3] TfLiteStatus InterpreterBuilder::operator()(
//once the code traces to here, the whole flow is basically clear: fetch the various pieces of data from the flatbuffer, assign them to the interpreter, and thereby create the interpreter
std::unique_ptr<Interpreter>* interpreter) {  // the unique_ptr object manages the Interpreter
if (!interpreter) {
error_reporter_->Report(
"Null output pointer passed to InterpreterBuilder.");
return kTfLiteError;
}
// Safe exit by deleting partially created interpreter, to reduce verbosity
// on error conditions. Usage: return cleanup_and_error();
auto cleanup_and_error = [&interpreter]() {
interpreter->reset();
return kTfLiteError;
};
if (!model_) {
error_reporter_->Report("Null pointer passed in as model.");
return cleanup_and_error();
}
3.1 model
if (model_->version() != TFLITE_SCHEMA_VERSION) {
error_reporter_->Report(
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model_->version(), TFLITE_SCHEMA_VERSION);
return cleanup_and_error();
}
if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
error_reporter_->Report("Registration failed.\n");
return cleanup_and_error();
}
// Flatbuffer model schemas define a list of opcodes independent of the graph.
// We first map those to registrations. This reduces string lookups for custom
// ops since we only do it once per custom op rather than once per custom op
// invocation in the model graph.
// Construct interpreter with correct number of tensors and operators.
auto* subgraphs = model_->subgraphs();
auto* buffers = model_->buffers();
if (subgraphs->size() != 1) {
error_reporter_->Report("Only 1 subgraph is currently supported.\n");
return cleanup_and_error();
}
3.2 SubGraph
const tflite::SubGraph* subgraph = (*subgraphs)[0];
auto operators = subgraph->operators();
//what does auto mean here? (operators is a local variable, not a member.) The auto keyword asks the compiler to deduce the variable's type automatically
auto tensors = subgraph->tensors();
if (!operators || !tensors || !buffers) {
error_reporter_->Report(
"Did not get operators, tensors, or buffers in input flat buffer.\n");
return cleanup_and_error();
}
3.3 new Interpreter(error_reporter)
interpreter->reset(new Interpreter(error_reporter_));  // the unique_ptr takes ownership of the Interpreter object
3.3.1 AddTensors, preserving pre-existing Tensor entries.
if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
return cleanup_and_error();
}
3.3.2 Parse inputs/outputs
(**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
(**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
// Finally setup nodes and tensors
if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
return kTfLiteOk;
}
The model struct
model_->version();
BuildLocalIndexToRegistrationMapping(), model_->operator_codes();
model_->subgraphs();
model_->buffers();
SubGraph
const tflite::SubGraph* subgraph = (*subgraphs)[0];
operators = subgraph->operators();
tensors = subgraph->tensors();
new Interpreter  // only at this point is the Interpreter object actually created
interpreter->reset(new Interpreter(error_reporter_));
Interpreter::Interpreter(ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
context_.impl_ = static_cast<void*>(this);
context_.ResizeTensor = ResizeTensor;
context_.ReportError = ReportError;
context_.AddTensors = AddTensors;
context_.tensors = nullptr;
context_.tensors_size = 0;
context_.gemm_context = nullptr;
// Invalid to call these except from TfLiteDelegate
context_.GetNodeAndRegistration = nullptr;
context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
context_.GetExecutionPlan = nullptr;
// Reserve some space for the tensors to avoid excessive resizing.
tensors_.reserve(kSlotsToReserve);
nodes_and_registration_.reserve(kSlotsToReserve);
next_execution_plan_index_to_prepare_ = 0;
UseNNAPI(false);
}
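The function pointers assigned above are how C kernels call back into C++: each is a static member function with a C-compatible signature that recovers the object from context->impl_. A simplified sketch of the pattern (the real code forwards to a private *Impl method):
TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context,
                                       TfLiteTensor* tensor,
                                       TfLiteIntArray* new_size) {
  // Recover the C++ object stashed in impl_ and forward the call.
  return static_cast<Interpreter*>(context->impl_)
      ->ResizeTensorImpl(tensor, new_size);
}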
After the Interpreter is created, populate its tensors, inputs/outputs, ops, etc.
Adding Tensors to the SubGraph
(**interpreter).AddTensors(tensors->Length())
// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
// The value pointed to by `first_new_tensor_index` will be set to the
// index of the first new tensor if `first_new_tensor_index` is non-null.
TfLiteStatus AddTensors(int tensors_to_add,
int* first_new_tensor_index = nullptr);
TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
int* first_new_tensor_index) {
int base_index = tensors_.size();  // is this 0 here? On the first call yes, since reserve() only changes capacity, not size()
if (first_new_tensor_index) *first_new_tensor_index = base_index;
tensors_.resize(tensors_.size() + tensors_to_add);
for (int i = base_index; i < tensors_.size(); i++) {
memset(&tensors_[i], 0, sizeof(tensors_[i]));
}
context_.tensors = tensors_.data();
context_.tensors_size = tensors_.size();
return kTfLiteOk;
}
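A short usage sketch: on a freshly constructed Interpreter, tensors_ is empty, so the first call reports index 0:
int first_new_tensor_index = -1;
interpreter->AddTensors(3, &first_new_tensor_index);
// first_new_tensor_index == 0; three zero-initialized entries were appended.
interpreter->AddTensors(2, &first_new_tensor_index);
// first_new_tensor_index == 3; pre-existing entries are preserved.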
Setting the SubGraph's inputs/outputs
(**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
(**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
TF_LITE_ENSURE_OK(&context_,
CheckTensorIndices("inputs", inputs.data(), inputs.size()));
inputs_ = std::move(inputs);
return kTfLiteOk;
}
TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
TF_LITE_ENSURE_OK(
&context_, CheckTensorIndices("outputs", outputs.data(), outputs.size()));
outputs_ = std::move(outputs);
return kTfLiteOk;
}
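Usage sketch (the indices are hypothetical): the arguments are indices into the tensors_ array populated by AddTensors above.
interpreter->SetInputs({0, 1});  // tensors 0 and 1 feed the graph
interpreter->SetOutputs({4});    // tensor 4 is read back as the result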
ParseNodes and ParseTensors
// Finally setup nodes and tensors
if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();