TFLite:代码分析(2):创建 interpreter

TFLite数据结构

lite/c/c_api_internal.h

// This file defines a C API for implementing operations in tflite.
// These operations can be defined using c++ but the interface between
// the interpreter and the operations are C. (operations的实现可以是c++,但interpreter和operation间的接口是C)
//
// Summary of abstractions                                                                                                               
// TF_LITE_ENSURE - Self-sufficient error checking
// TfLiteStatus - Status reporting
// TfLiteIntArray - stores tensor shapes (dims),
// TfLiteContext - allows an op to access the tensors
// TfLiteTensor - tensor (a multidimensional array)
// TfLiteNode - a single node or operation
// TfLiteRegistration - the implementation of a conceptual operation.

TfLiteTensor

// A tensor in the interpreter system which is a wrapper around a buffer of
// data including a dimensionality (or NULL if not currently defined).

TfLiteNode

// A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs and user defined data, not
// other features like the type.

TfLiteContext

TfLiteRegistration

其中包含了函数指针init, free, prepare, invoke and profiling_string; 表示builtin所用的op_code, 用户定制的name 和版本号

// Describes how one operation is implemented: lifecycle callbacks (init, free,
// prepare, invoke), an optional profiling hook, and identifying metadata
// (builtin code, custom op name, version).
typedef struct _TfLiteRegistration {
  // Initializes the op from serialized data.
  // If a built-in op:
  //   (NOTE(review): do `buffer`/`length` come from the flatbuffer? — confirm)
  //   `buffer` is the op's params data (TfLiteLSTMParams*).
  //   `length` is zero.
  // If custom op:
  //   `buffer` is the op's `custom_options`.
  //   `length` is the size of the buffer.
  //
  // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
  // or an instance of a struct).
  //
  // The returned pointer will be stored with the node in the `user_data` field,
  // accessible within prepare and invoke functions below.
  // NOTE: if the data is already in the desired format, simply implement this
  // function to return `nullptr` and implement the free function to be a no-op.
  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);

  // The pointer `buffer` is the data previously returned by an init invocation.
  void (*free)(TfLiteContext* context, void* buffer);

  // prepare is called when the inputs this node depends on have been resized.
  // context->ResizeTensor() can be called to request output tensors to be
  // resized.
  // (In short: this function is called whenever the node's inputs are resized.)
  // Returns kTfLiteOk on success.
  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);

  // Execute the node (should read node->inputs and output to node->outputs).
  // Returns kTfLiteOk on success.
  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);

  // profiling_string is called during summarization of profiling information
  // in order to group executions together. Providing a value here will cause a
  // given op to appear multiple times in the profiling report. This is
  // particularly useful for custom ops that can perform significantly
  // different calculations depending on their `user-data`.
  const char* (*profiling_string)(const TfLiteContext* context,
                                  const TfLiteNode* node);

  // Builtin codes. If this kernel refers to a builtin this is the code
  // of the builtin. This is so we can do marshaling to other frameworks like
  // NN API.
  // Note: It is the responsibility of the registration binder to set this
  // properly.
  int32_t builtin_code;

  // Custom op name. If the op is a builtin, this will be null.
  // Note: It is the responsibility of the registration binder to set this
  // properly.
  // WARNING: This is an experimental interface that is subject to change.
  const char* custom_name;

  // The version of the op.
  // Note: It is the responsibility of the registration binder to set this
  // properly.
  int version;
} TfLiteRegistration;

注册operators的相关文件register.cc

c++文件的名字和类名的关系和java不同,c++中类名和文件名没有强烈关系。

register.h文件中定义BuiltinOpResolver

namespace tflite {
namespace ops {
namespace builtin {

// Op resolver derived from MutableOpResolver. It overrides both FindOp
// overloads: one keyed by the builtin operator enum and one keyed by a
// string, each together with a version number.
class BuiltinOpResolver : public MutableOpResolver {
 public:
  BuiltinOpResolver();

  // Returns the registration for builtin operator `op` at `version`.
  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 
                                   int version) const override;
  // Returns the registration for the op identified by the string `op` at
  // `version` (presumably a custom op name — confirm against the base class).
  const TfLiteRegistration* FindOp(const char* op, int version) const override;
};

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
 

register.cc的代码结构: OpResolver的注册

namespace tflite {
namespace ops {
  
namespace builtin {
  TfLiteRegistration* Register_RELU();

 1] Register_XXX()的实现(以Register_SOFTMAX()为例),返回static表示的全局变量TfLiteRegistration

TfLiteRegistration* Register_SOFTMAX() {
    // A single function-local static registration is shared by every caller.
    // Only the first four members (init, free, prepare, invoke) are listed;
    // the remaining members are zero-initialized by aggregate initialization.
    static TfLiteRegistration softmax_registration = {
        activations::Init,
        activations::Free,
        activations::SoftmaxPrepare,
        activations::SoftmaxEval,
    };
    TfLiteRegistration* const handle = &softmax_registration;
    return handle;
}

3] FindOp就是查找 全局变量builtins_ and custom_ops

  // Looks up the registration for builtin operator `op` at `version` by
  // forwarding the arguments unchanged to MutableOpResolver::FindOp.
  const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op, 
                                                      int version) const {
    return MutableOpResolver::FindOp(op, version);
  }

 

  BuiltinOpResolver::BuiltinOpResolver() {
    AddBuiltin(BuiltinOperator_RELU, Register_RELU());

    2] AddBuiltin的实现: 就是写入全局变量builtins_ and custom_ops

    void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, 
                                   const TfLiteRegistration* registration,
                                   int min_version, int max_version) {
  for (int version = min_version; version <= max_version; ++version) {
    TfLiteRegistration new_registration = *registration;
    new_registration.custom_name = nullptr;
    new_registration.builtin_code = op; 
    new_registration.version = version;
    auto op_key = std::make_pair(op, version);
    builtins_[op_key] = new_registration;
  }
}   
 
  }
}//builtin
}//ops
}//tflite

native: createInterpreter

// JNI entry point: builds a tflite::Interpreter from the native model behind
// `model_handle` and returns it to Java as an opaque jlong handle.
//
// Returns 0 when the model handle does not resolve to a model, or when the
// builder leaves the interpreter pointer empty.
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle) {
  tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
  // Guard the dereference below: a null model must not crash the process.
  if (model == nullptr) return 0;

  // Still converted so that an invalid error handle is processed by the
  // helper; the reporter itself is not used further in this function.
  BufferErrorReporter* error_reporter =
      convertLongToErrorReporter(env, error_handle);
  (void)error_reporter;

  auto resolver = ::tflite::CreateOpResolver();

  std::unique_ptr<tflite::Interpreter> interpreter;
  // Build the interpreter from the flatbuffer model and the op resolver.
  TfLiteStatus status =
      tflite::InterpreterBuilder(*model, *resolver)(&interpreter);
  // NOTE(review): the status is not inspected here. Per the operator() shown
  // in this article, failure paths reset `interpreter`, so release() below
  // then yields a null handle (0) — confirm for the version in use.
  (void)status;

  // Ownership passes to the Java side, which holds the raw pointer as a long.
  return reinterpret_cast<jlong>(interpreter.release());
}

1] CreateOpResolver:builtin_ops_jni.cc

// The JNI code in nativeinterpreterwrapper_jni.cc expects a CreateOpResolver() function in
// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the
// builtin ops.
// For smaller binary sizes users should avoid linking this in, and
// should provide a custom CreateOpResolver() instead.

//CreateOpResolver该函数实现了所有operator的注册方法,如果要裁剪请修改实现
std::unique_ptr<OpResolver> CreateOpResolver() {  // NOLINT
  return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(:
      new tflite::ops::builtin::BuiltinOpResolver()); //请参考上面数据结构的分析
}

 

2] tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter);

// Constructs the builder. It only stores what it needs for a later build:
// the parsed model (model.GetModel()), the op resolver, and the model's
// error reporter and allocation. No interpreter is created here.
InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
                                       const OpResolver& op_resolver)
    : model_(model.GetModel()),
      op_resolver_(op_resolver),
      error_reporter_(model.error_reporter()),
      allocation_(model.allocation()) 
{}  // This is InterpreterBuilder's constructor; do not confuse it with Interpreter's constructor.
后面的(&interpreter)莫名其妙,是什么意思?其实这里重载了():

  // Convenience overload of the call operator: forwards to the two-argument
  // overload, passing -1 as num_threads, and returns its status.
  TfLiteStatus InterpreterBuilder::operator()(
      std::unique_ptr<Interpreter>* interpreter) {
    return operator()(interpreter, /*num_threads=*/-1);
  }
  
  // Two-argument overload of the call operator, taking the output pointer and
  // a thread count.
  // NOTE(review): the body is elided in this excerpt, so as written the
  // function has no return statement.
  TfLiteStatus InterpreterBuilder::operator()(
      std::unique_ptr<Interpreter>* interpreter, int num_threads) {

}
model文件怎样解析生成context等都是这个函数实现的

3] TfLiteStatus InterpreterBuilder::operator()(

   // Once you have traced the code to this point the whole flow is basically
   // clear: this function reads the various pieces of data out of the
   // flatbuffer, creates the interpreter and populates it.
    std::unique_ptr<Interpreter>* interpreter) {  // a unique_ptr object manages the Interpreter
  if (!interpreter) {
    error_reporter_->Report(
        "Null output pointer passed to InterpreterBuilder.");
    return kTfLiteError;
  }

  // Safe exit by deleting partially created interpreter, to reduce verbosity
  // on error conditions. Use by return cleanup_and_error();
  auto cleanup_and_error = [&interpreter]() {
    interpreter->reset();
    return kTfLiteError;
  };

  if (!model_) {
    error_reporter_->Report("Null pointer passed in as model.");
    return cleanup_and_error();
  }

  3.1 model


  if (model_->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter_->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.\n",
        model_->version(), TFLITE_SCHEMA_VERSION);
    return cleanup_and_error();
  }
  

  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

  // Flatbuffer model schemas define a list of opcodes independent of the graph.
  // We first map those to registrations. This reduces string lookups for custom
  // ops since we only do it once per custom op rather than once per custom op
  // invocation in the model graph.
  // Construct interpreter with correct number of tensors and operators.
  auto* subgraphs = model_->subgraphs();
  auto* buffers = model_->buffers();
  if (subgraphs->size() != 1) {
    error_reporter_->Report("Only 1 subgraph is currently supported.\n");
    return cleanup_and_error();
  }

  3.2 SubGraph
  const tflite::SubGraph* subgraph = (*subgraphs)[0];
  auto operators = subgraph->operators();

 // What does `auto` mean here (operators is not a member variable, is it)?
 // The `auto` keyword asks the compiler to deduce the variable's type
 // automatically from its initializer.
  auto tensors = subgraph->tensors();
  if (!operators || !tensors || !buffers) {
    error_reporter_->Report(
        "Did not get operators, tensors, or buffers in input flat buffer.\n");
    return cleanup_and_error();
  }

  3.3 new Interpreter(error_reporter)


  interpreter->reset(new Interpreter(error_reporter_));  // the unique_ptr `*interpreter` takes ownership of the new Interpreter object

  3.3.1 preserving pre-existing Tensor entries.
  if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
    return cleanup_and_error();
  }

  3.3.2 // Parse inputs/outputs
  (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
  (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));

  // Finally setup nodes and tensors
  if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();


  if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();

  return kTfLiteOk;
}

model结构体


    model_->version();
    BuildLocalIndexToRegistrationMapping(), model_->operator_codes();
    model_->subgraphs();
    model_->buffers();    

SubGraph


  const tflite::SubGraph* subgraph = (*subgraphs)[0];
  operators = subgraph->operators();
  tensors = subgraph->tensors();

new Interpreter //到这里才创建了Interpreter对象


interpreter->reset(new Interpreter(error_reporter_));    

// Constructs an empty interpreter: no tensors and no nodes yet.
// `error_reporter` may be null, in which case DefaultErrorReporter() is used.
Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  // Wire up the C context: a back-pointer to this object plus the callbacks.
  // (Presumably impl_ lets the callbacks recover the Interpreter — confirm.)
  context_.impl_ = static_cast<void*>(this);
  context_.ResizeTensor = ResizeTensor;
  context_.ReportError = ReportError;
  context_.AddTensors = AddTensors;
  // No tensors exist yet; AddTensors() fills these in later.
  context_.tensors = nullptr;
  context_.tensors_size = 0;
  context_.gemm_context = nullptr;

  // Invalid to call these except from TfLiteDelegate
  context_.GetNodeAndRegistration = nullptr;
  context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
  context_.GetExecutionPlan = nullptr;

  // Reserve some space for the tensors to avoid excessive resizing.
  // reserve() only changes capacity, so both containers are still empty.
  tensors_.reserve(kSlotsToReserve);
  nodes_and_registration_.reserve(kSlotsToReserve);
  next_execution_plan_index_to_prepare_ = 0;
  // Presumably starts with NNAPI disabled — confirm against UseNNAPI().
  UseNNAPI(false);
}

创建Interpreter后,赋值tensors/input/output/Op等

向SubGraph中添加 Tensors


  (**interpreter).AddTensors(tensors->Length())
  // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  // The value pointed to by `first_new_tensor_index` will be set to the
  // index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

// Appends `tensors_to_add` zero-filled tensors after the existing ones and
// refreshes the tensor pointer and count exposed through context_.
// If `first_new_tensor_index` is non-null it receives the index of the first
// appended tensor. Always returns kTfLiteOk.
TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  // Index of the first new slot equals the tensor count before growing.
  // For an interpreter that has only been constructed this is 0, since the
  // constructor shown in this article reserves capacity without adding
  // elements.
  int first_added = tensors_.size();
  if (first_new_tensor_index != nullptr) {
    *first_new_tensor_index = first_added;
  }

  tensors_.resize(tensors_.size() + tensors_to_add);

  // Clear only the newly added slots; existing entries are left untouched.
  int slot = first_added;
  while (slot < tensors_.size()) {
    memset(&tensors_[slot], 0, sizeof(tensors_[slot]));
    slot++;
  }

  // Growing the vector may have moved its storage, so republish both the
  // data pointer and the size.
  context_.tensors_size = tensors_.size();
  context_.tensors = tensors_.data();
  return kTfLiteOk;
}


设置SubGraph的 input/output


  (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
  (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));

    // Records which tensor indices are the inputs. The indices are first
    // passed to CheckTensorIndices; the vector is then moved into inputs_.
    // Returns kTfLiteOk once stored.
    TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
      // NOTE(review): TF_LITE_ENSURE_OK presumably returns early with the
      // failing status when the check does not yield kTfLiteOk — confirm.
      TF_LITE_ENSURE_OK(&context_,
                    CheckTensorIndices("inputs", inputs.data(), inputs.size()));
      inputs_ = std::move(inputs);
      return kTfLiteOk;
    }

    // Records which tensor indices are the outputs. The indices are first
    // passed to CheckTensorIndices; the vector is then moved into outputs_.
    // Returns kTfLiteOk once stored.
    TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
      // NOTE(review): TF_LITE_ENSURE_OK presumably returns early with the
      // failing status when the check does not yield kTfLiteOk — confirm.
      TF_LITE_ENSURE_OK(
          &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size()));
      outputs_ = std::move(outputs);
      return kTfLiteOk;
    }

ParseNodes and ParseTensors


  // Finally setup nodes and tensors
  if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();
  if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值