TFLite: Building the Interpreter

The overall flow of running inference with TFLite:

  const char* filename = argv[1];

  // Load model
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile(filename);
  TFLITE_MINIMAL_CHECK(model != nullptr);

  // Build the interpreter
  tflite::ops::builtin::BuiltinOpResolver resolver;
  InterpreterBuilder builder(*model.get(), resolver);
  std::unique_ptr<Interpreter> interpreter;
  builder(&interpreter); // calls TfLiteStatus InterpreterBuilder::operator()(std::unique_ptr<Interpreter>* interpreter)
  TFLITE_MINIMAL_CHECK(interpreter != nullptr);

  // Allocate tensor buffers.
  TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
  printf("=== Pre-invoke Interpreter State ===\n");
  tflite::PrintInterpreterState(interpreter.get());

  // Fill input buffers
  // TODO(user): Insert code to fill input tensors

  // Run inference
  TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
  printf("\n\n=== Post-invoke Interpreter State ===\n");
  tflite::PrintInterpreterState(interpreter.get());
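
The TODO above is where application code writes the input before Invoke(), and outputs are read back afterwards. A minimal sketch, assuming a model with a single float32 input and a single float32 output (kNumInputElements is a hypothetical, model-dependent size):

  // Before Invoke(): fill input tensor 0. typed_input_tensor returns a pointer
  // into the buffer that AllocateTensors() set up for this tensor.
  float* input = interpreter->typed_input_tensor<float>(0);
  for (int i = 0; i < kNumInputElements; ++i) input[i] = 0.5f;

  // After Invoke(): read output tensor 0.
  float* output = interpreter->typed_output_tensor<float>(0);
  printf("output[0] = %f\n", output[0]);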

Classes involved in building the interpreter: FlatBufferModel, InterpreterBuilder, Interpreter.

 

InterpreterBuilder::operator()

The Interpreter is created by the overloaded call operator in model.cc, InterpreterBuilder::operator():

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter, int num_threads) {

    // Check the schema version of the flatbuffer model.
    if (model_->version() != TFLITE_SCHEMA_VERSION) {
      // error handling elided
    }

    // Map every OperatorCode in the model to a TfLiteRegistration.
    if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
      // error handling elided
    }

    auto* subgraphs = model_->subgraphs();
    auto* buffers = model_->buffers();
    if (subgraphs->size() != 1) {
      // only a single subgraph is supported here; error handling elided
    }

    const tflite::SubGraph* subgraph = (*subgraphs)[0];
    auto operators = subgraph->operators();
    auto tensors = subgraph->tensors();
    if (!operators || !tensors || !buffers) {
      // error handling elided
    }

    interpreter->reset(new Interpreter(error_reporter_));
    if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
      // error handling elided
    }

    // Set num threads
    (**interpreter).SetNumThreads(num_threads);
    // Parse inputs/outputs
    (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
    (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));

    // Finally setup nodes and tensors
    if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
      return cleanup_and_error();
    if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
      return cleanup_and_error();

    // Collect the indices of variable tensors.
    std::vector<int> variables;
    for (int i = 0; i < (*interpreter)->tensors_size(); ++i) {
      auto* tensor = (*interpreter)->tensor(i);
      if (tensor->is_variable) {
        variables.push_back(i);
      }
    }
    (**interpreter).SetVariables(std::move(variables));
    return kTfLiteOk;
}

1. model_: the ::tflite::Model obtained from the FlatBufferModel

 const ::tflite::Model* model_;

  FlatBufferModel::FlatBufferModel(Allocation* allocation,
                                   ErrorReporter* error_reporter)
      : error_reporter_(ValidateErrorReporter(error_reporter)) {
    allocation_ = allocation;
    if (!allocation_->valid() || !CheckModelIdentifier()) return;
  
    model_ = ::tflite::GetModel(allocation_->base());
  }

  uint32_t version() const    {                                                                        
    return GetField<uint32_t>(VT_VERSION, 0);
  }
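
Besides BuildFromFile, a FlatBufferModel can also be built from a caller-owned buffer. A minimal sketch (model_buffer and model_buffer_size are placeholders):

  // The caller keeps model_buffer alive for the whole lifetime of the model
  // and of any Interpreter built from it.
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(model_buffer, model_buffer_size);
  TFLITE_MINIMAL_CHECK(model != nullptr);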

2. BuildLocalIndexToRegistrationMapping()


  // For each entry in model_->operator_codes(), look up the OpResolver and obtain a TfLiteRegistration.
  TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
    TfLiteStatus status = kTfLiteOk;
    auto opcodes = model_->operator_codes();
    for (const OperatorCode* opcode : *opcodes) {  // note: iterated with a range-for rather than Get(i)
      const TfLiteRegistration* registration = nullptr;
      status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,                                       
                                         &registration);
      flatbuffer_op_index_to_registration_.push_back(registration);
    }
    return status;
  }
  

For each OperatorCode stored in the flatbuffer model, opcode->builtin_code() is read and the generic interface
GetRegistrationFromOpCode is called. That call relies on op_resolver_, which is whatever resolver the user passed in,
e.g. a BuiltinOpResolver. The inheritance chain is OpResolver → MutableOpResolver → BuiltinOpResolver, covered in 2.1-2.3 below:


2.1 class OpResolver


core/api/op_resolver.h 
namespace tflite {

// Abstract interface that returns TfLiteRegistrations given op codes or custom
// op names. This is the mechanism that ops being referenced in the flatbuffer
// model are mapped to executable function pointers (TfLiteRegistrations).
class OpResolver {
 public:
  // Finds the op registration for a builtin operator by enum code.
  virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 
                                           int version) const = 0;
  // Finds the op registration of a custom operator by op name.
  virtual const TfLiteRegistration* FindOp(const char* op, 
                                           int version) const = 0;
  virtual ~OpResolver() {}
};

// Handles the logic for converting between an OperatorCode structure extracted
// from a flatbuffer and information about a registered operator implementation.
TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
                                       const OpResolver& op_resolver,
                                       ErrorReporter* error_reporter,
                                       const TfLiteRegistration** registration);

}  // namespace tflite

TfLiteStatus GetRegistrationFromOpCode(
    const OperatorCode* opcode, const OpResolver& op_resolver,
    ErrorReporter* error_reporter, const TfLiteRegistration** registration) {
  TfLiteStatus status = kTfLiteOk;
  *registration = nullptr;

  auto builtin_code = opcode->builtin_code();
  int version = opcode->version();

  if (builtin_code != BuiltinOperator_CUSTOM) {
    *registration = op_resolver.FindOp(builtin_code, version);

  } else if (!opcode->custom_code()) {
    // custom op with no custom_code string: error reporting elided
  } else {
    const char* name = opcode->custom_code()->c_str();
    *registration = op_resolver.FindOp(name, version);
  }
  return status;
}


2.2 MutableOpResolver


mutable_op_resolver.h
// An OpResolver that is mutable, also used as the op in gen_op_registration.
// A typical usage:
//   MutableOpResolver resolver;
//   resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
//   resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
//   InterpreterBuilder(model, resolver)(&interpreter);
class MutableOpResolver : public OpResolver {
 public:
  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 
                                   int version) const override;
  const TfLiteRegistration* FindOp(const char* op, int version) const override;
  void AddBuiltin(tflite::BuiltinOperator op, 
                  const TfLiteRegistration* registration, int min_version = 1,
                  int max_version = 1); 
  void AddCustom(const char* name, const TfLiteRegistration* registration,
                 int min_version = 1, int max_version = 1); 
  void AddAll(const MutableOpResolver& other);

 private:
  typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
  typedef std::pair<std::string, int> CustomOperatorKey;

  std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
                     op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
      builtins_;
  std::unordered_map<CustomOperatorKey, TfLiteRegistration,
                     op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
      custom_ops_;
};
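
Each registration returned by FindOp is a plain TfLiteRegistration whose function pointers implement the op. A minimal, hypothetical custom-op registration (Register_MY_OP, MyOpPrepare and MyOpEval are made-up names used only for illustration):

namespace tflite {
namespace ops {
namespace custom {

TfLiteStatus MyOpPrepare(TfLiteContext* context, TfLiteNode* node) {
  // Typically: check input/output counts and resize the output tensors.
  return kTfLiteOk;
}

TfLiteStatus MyOpEval(TfLiteContext* context, TfLiteNode* node) {
  // Read input tensors, compute, write output tensors.
  return kTfLiteOk;
}

TfLiteRegistration* Register_MY_OP() {
  // TfLiteRegistration field order: init, free, prepare, invoke.
  static TfLiteRegistration r = {nullptr, nullptr, MyOpPrepare, MyOpEval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite

// Usage with a MutableOpResolver:
//   resolver.AddCustom("MyOp", tflite::ops::custom::Register_MY_OP());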

2.3 kernel/register


kernel/register.h
namespace tflite {
namespace ops {
namespace builtin {

class BuiltinOpResolver : public MutableOpResolver {
 public:
  BuiltinOpResolver();

  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 
                                   int version) const override;
  const TfLiteRegistration* FindOp(const char* op, int version) const override;
};

}  // namespace builtin
}  // namespace ops
}  // namespace tflite

BuiltinOpResolver::BuiltinOpResolver() {  // the constructor registers every builtin op
 AddBuiltin(BuiltinOperator_RELU, Register_RELU());
 // ... the remaining AddBuiltin calls are elided ...
 AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
}

2.4 Customizing the OpResolver

For example, when the TFLite binary is too large and needs trimming, or when extra operators must be added.

2.4.1 Modifying tflite::ops::builtin::BuiltinOpResolver in place

In register.cc, for each operator to drop, delete its declaration and its registration, e.g.:

delete TfLiteRegistration* Register_RNN();

delete AddBuiltin(BuiltinOperator_RNN, Register_RNN());

In the BUILD file, remove the corresponding kernel source from the "builtin_op_kernels" target so it is no longer compiled, e.g. under

name = "builtin_op_kernels",

delete the "lstm.cc", entry.

2.4.2 Writing your own classes to replace register.h and register.cc
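
A minimal sketch of such a hand-written resolver, assuming only the ADD and CONV_2D kernels are needed (MyOpResolver is a made-up name; the include path depends on the TensorFlow version):

// my_op_resolver.h (hypothetical)
#include "tensorflow/lite/mutable_op_resolver.h"   // path varies by TF version

namespace tflite {
namespace ops {
namespace builtin {
// Declarations of the kernels we keep; they are defined in the kernel library.
TfLiteRegistration* Register_ADD();
TfLiteRegistration* Register_CONV_2D();
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

// Registers only the ops this application actually uses, so the remaining
// builtin kernels can be removed from the BUILD target.
class MyOpResolver : public tflite::MutableOpResolver {
 public:
  MyOpResolver() {
    AddBuiltin(tflite::BuiltinOperator_ADD, tflite::ops::builtin::Register_ADD());
    AddBuiltin(tflite::BuiltinOperator_CONV_2D,
               tflite::ops::builtin::Register_CONV_2D());
  }
};

// Usage: MyOpResolver resolver; tflite::InterpreterBuilder(*model, resolver)(&interpreter);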

 

3. (**interpreter).AddTensors(tensors->Length())


This resizes std::vector<TfLiteTensor> tensors_, zero-initializing the new entries. To avoid copies, only the data pointer and size are synchronized into the TfLiteContext context_. Note that this function does not yet read any tensor contents out of the flatbuffer.

3.1 std::vector<TfLiteTensor> tensors_


  // Tensors needed by the interpreter. Use `AddTensors` to add more blank
  // tensor entries. Note, `tensors_.data()` needs to be synchronized to the
  // `context_` whenever this std::vector is reallocated. Currently this
  // only happens in `AddTensors()`.
  std::vector<TfLiteTensor> tensors_;

3.2 Declaration of AddTensors


  // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  // The value pointed to by `first_new_tensor_index` will be set to the
  // index of the first new tensor if `first_new_tensor_index` is non-null.

  TfLiteStatus AddTensors(int tensors_to_add,                                                                                                       
                          int* first_new_tensor_index = nullptr);

3.3 Implementation of AddTensors


  TfLiteStatus Interpreter::AddTensors(int tensors_to_add,                                                                                                      
                                       int* first_new_tensor_index) {
    int base_index = tensors_.size();
    if (first_new_tensor_index) *first_new_tensor_index = base_index;
    tensors_.resize(tensors_.size() + tensors_to_add);
    for (int i = base_index; i < tensors_.size(); i++) {
      memset(&tensors_[i], 0, sizeof(tensors_[i]));
      tensors_[i].buffer_handle = kTfLiteNullBufferHandle;
    }
    context_.tensors = tensors_.data();
    context_.tensors_size = tensors_.size();
    return kTfLiteOk;
  }
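
The same Interpreter API can also be used outside InterpreterBuilder to assemble a graph by hand, which shows how AddTensors, SetInputs/SetOutputs (section 4) and SetTensorParametersReadWrite (section 6.3) fit together. A rough sketch with made-up indices and shapes:

  tflite::Interpreter interpreter;
  int first_new = 0;
  interpreter.AddTensors(3, &first_new);        // adds tensors 0, 1, 2; first_new == 0
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  TfLiteQuantizationParams quant = {0.0f, 0};   // scale, zero_point
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {2, 2}, quant);
  interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "b", {2, 2}, quant);
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "out", {2, 2}, quant);
  // A node would then be attached with AddNodeWithParameters(...) (section 5.4).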
 

4. Interpreter SetInputs/SetOutputs


(**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
(**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));

4.1 FlatBufferIntArrayToVector


  // flat_array points at an array stored in the flatbuffer; its elements are read via Get(i).
  template <typename T>
  std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {  // type conversion
    // Initialize shape of tensors with null shape. Empty vectors are converted
    // to nullptr for models that are constructed via flatbuffers::Pack.
    if (flat_array == nullptr) {
      return {}; 
    }
    std::vector<int> ret(flat_array->Length());
    for (int i = 0; i < flat_array->Length(); i++) {
      ret[i] = flat_array->Get(i);                                                                                                                              
    }
    return ret;
  }

4.2 std::move(inputs)


  // Array of indices representing the tensors that are inputs to the
  // interpreter. (The array holds the index numbers of the input tensors.)
  std::vector<int> inputs_;
  TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {                                                                    
    TF_LITE_ENSURE_OK(&context_,
                      CheckTensorIndices("inputs", inputs.data(), inputs.size()));
    inputs_ = std::move(inputs);
    return kTfLiteOk;
  }

4.2.1 CheckTensorIndices


  // Each tensor index is checked against context_.tensors_size, which was set by AddTensors above.
  TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
                                               const int* indices, int length) {
    // Making sure kOptionalTensor is not re-defined to something other than -1.
    static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1");
  
    for (int i = 0; i < length; i++) {
      int index = indices[i];
      // Continue if index == kOptionalTensor before additional comparisons below,
      // size_t(-1) is always >= context_tensors_size.
      if (index == kOptionalTensor) {
        continue;
      }    
      if (index < 0 || static_cast<size_t>(index) >= context_.tensors_size) {
        ReportError(&context_, "Invalid tensor index %d in %s\n", index, label);                                                     
        consistent_ = false;
        return kTfLiteError;
      }    
    }
    return kTfLiteOk;
  }

5. InterpreterBuilder::ParseNodes


TfLiteStatus InterpreterBuilder::ParseNodes(
      const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
      Interpreter* interpreter) {

    TfLiteStatus status = kTfLiteOk;
    // Reduce the number of redundant allocations
    interpreter->ReserveNodes(operators->Length());

    for (int i = 0; i < operators->Length(); ++i) {
      const auto* op = operators->Get(i);  // Get(i) again
      // opcode_index indexes into model_->operator_codes(); it has nothing to
      // do with the OpResolver itself.
      int index = op->opcode_index();

      const TfLiteRegistration* registration =
          flatbuffer_op_index_to_registration_[index];

      BuiltinOperator op_type =
          static_cast<BuiltinOperator>(registration->builtin_code);

      if (op->custom_options()) {
        interpreter->AddNodeWithParameters(
            FlatBufferIntArrayToVector(op->inputs()),
            FlatBufferIntArrayToVector(op->outputs()),
            reinterpret_cast<const char*>(op->custom_options()->data()),
            op->custom_options()->size(), nullptr, registration);
      } else {
        void* builtin_data = nullptr;
        MallocDataAllocator malloc_allocator;
        TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
                                          &malloc_allocator, &builtin_data));
        interpreter->AddNodeWithParameters(
            FlatBufferIntArrayToVector(op->inputs()),
            FlatBufferIntArrayToVector(op->outputs()), nullptr, 0,
            builtin_data, registration);
      }
    }
    return status;
}

5.1 ReserveNodes


  // Ensure the internal node storage memory allocates at least `count`
  // spots for node. NOTE, this doesn't actually add operators. This is an
  // efficiency optimization that is subject to change.
  void ReserveNodes(int count); 
  void Interpreter::ReserveNodes(int count) {                                                                  
    nodes_and_registration_.reserve(count);
  }
  
  // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
  // function pointers to actual implementation.
  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>                                                                                          
      nodes_and_registration_;

5.2 Obtaining the handler (TfLiteRegistration) for a node


flatbuffer_op_index_to_registration_[index] yields the node's TfLiteRegistration.
std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
This relies on the mapping built earlier by BuildLocalIndexToRegistrationMapping().

5.3 Obtaining the parameters for the handler function

Two parts: allocating memory to hold the builtin parameters, and reading those parameters out of the flatbuffer.

5.3.1 Allocating the memory

Ultimately plain malloc is used.
flatbuffer_conversions.h
// Interface class for builtin data allocations.
class BuiltinDataAllocator {
 public:
  virtual void* Allocate(size_t size) = 0;                                         
  virtual void Deallocate(void* data) = 0;

  // Allocate a structure, but make sure it is a POD structure that doesn't
  // require constructors to run. The reason we do this, is that Interpreter's C
  // extension part will take ownership so destructors  will not be run during
  // deallocation.
  template <typename T>
  T* AllocatePOD() {
    static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
    return static_cast<T*>(this->Allocate(sizeof(T)));
  }

  virtual ~BuiltinDataAllocator() {}
};

model.cc
  // Used to determine how the op data parsing function creates its working space.
  class MallocDataAllocator : public BuiltinDataAllocator {                                                                                                                                                 
   public:
    void* Allocate(size_t size) override { return malloc(size); }
    void Deallocate(void* data) override { free(data); }
  };

5.3.2 ParseOpData fills in builtin_data


  // Parse the appropriate data out of the op.
  //
  // This handles builtin data explicitly as there are flatbuffer schemas.
  // If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
  // needs to be released by calling `free`.
  // If it returns kTfLiteError, `builtin_data` will be `nullptr`.
  TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
                           ErrorReporter* error_reporter,
                           BuiltinDataAllocator* allocator, void** builtin_data) {
    *builtin_data = nullptr;
    switch (op_type) {
      case BuiltinOperator_CONV_2D: {
        TfLiteConvParams* params = allocator->AllocatePOD<TfLiteConvParams>();
        if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
          params->padding = parse_padding(conv_params->padding());
          params->stride_width = conv_params->stride_w();
          params->stride_height = conv_params->stride_h();
          params->activation =
              parse_activation(conv_params->fused_activation_function());

          params->dilation_width_factor = conv_params->dilation_w_factor();
          params->dilation_height_factor = conv_params->dilation_h_factor();
        }
        *builtin_data = reinterpret_cast<void*>(params);
        break;
      }
      // ... cases for the other builtin operators are elided ...
    }
    return kTfLiteOk;
  }

5.4 AddNodeWithParameters

The core work updates nodes_and_registration_ and finally does execution_plan_.push_back(new_node_index).

TfLiteStatus Interpreter::AddNodeWithParameters(
      const std::vector<int>& inputs, const std::vector<int>& outputs,
      const char* init_data, size_t init_data_size, void* builtin_data,
      const TfLiteRegistration* registration, int* node_index) {

    // Take ownership of builtin_data so it is freed on any error path.
    std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
                                                                free);

    // Note: size() is the number of nodes already added, i.e. the index of the new node.
    int new_node_index = nodes_and_registration_.size();
    if (node_index) *node_index = new_node_index;
    nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
    // Grab the new node_and_reg entry.
    auto& node_and_reg = nodes_and_registration_.back();
    // `first` is the TfLiteNode half of the pair; free any data it may already hold.
    TfLiteNode& node = node_and_reg.first;
    if (node.inputs) TfLiteIntArrayFree(node.inputs);
    if (node.outputs) TfLiteIntArrayFree(node.outputs);
    if (node.temporaries) TfLiteIntArrayFree(node.temporaries);

    // NOTE, here we are not using move semantics yet, since our internal
    // representation isn't std::vector, but in the future we would like to avoid
    // copies, so we want the interface to take r-value references now.
    node.inputs = ConvertVectorToTfLiteIntArray(inputs);
    node.outputs = ConvertVectorToTfLiteIntArray(outputs);
    node.temporaries = TfLiteIntArrayCreate(0);

    // If init_data was supplied, hand it to OpInit; otherwise pass builtin_data.
    if (init_data) {
      node.user_data = OpInit(*registration, init_data, init_data_size);
    } else {
      node.user_data = OpInit(
          *registration,
          reinterpret_cast<const char*>(builtin_data_deleter.get()), 0);
    }
    // Hand ownership of builtin_data to the node.
    node.builtin_data = builtin_data_deleter.release();
    // TODO(ycling): Filling `custom_initial_data` and `custom_initial_data_size`
    // properly for nodes generated by ReplaceSubgraphsWithDelegateKernels.

    // Custom data.
    if (registration->builtin_code == BuiltinOperator_CUSTOM) {
      // When it's a CUSTOM op, the `custom_options` field in the Flatbuffer
      // `Operator` table is passed in.
      node.custom_initial_data = init_data;
      node.custom_initial_data_size = init_data_size;
    } else {
      node.custom_initial_data = nullptr;
      node.custom_initial_data_size = 0;
    }

    node.delegate = nullptr;
    // `second` is the TfLiteRegistration half of the pair.
    node_and_reg.second = *registration;
    execution_plan_.push_back(new_node_index);
    return kTfLiteOk;
  }

5.4.1 nodes_and_registration_


  // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
  // function pointers to actual implementation.
  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>                                                                                          
      nodes_and_registration_;

5.4.2 op_reg:OpInit


  // Give 'op_reg' a chance to initialize itself using the contents of
  // 'buffer'.
  void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,                                                                            
               size_t length) { 
    if (op_reg.init == nullptr) return nullptr;
    return op_reg.init(&context_, buffer, length);
  }
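
OpInit has sibling wrappers that forward to the other TfLiteRegistration function pointers. An abbreviated sketch of how they look in interpreter.h:

  // Let 'op_reg' release any memory it allocated in OpInit.
  void OpFree(const TfLiteRegistration& op_reg, void* buffer) {
    if (op_reg.free == nullptr) return;
    if (buffer) op_reg.free(&context_, buffer);
  }

  // Prepare (shape inference, scratch allocation) the given node.
  TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) {
    if (op_reg.prepare == nullptr) return kTfLiteOk;
    return op_reg.prepare(&context_, node);
  }

  // Run the op on the given node.
  TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) {
    if (op_reg.invoke == nullptr) return kTfLiteError;
    return op_reg.invoke(&context_, node);
  }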

5.4.3 execution_plan_


  // WARNING: This is an experimental interface that is subject to change.
  // This is a list of node indices (to index into nodes_and_registration).
  // This represents a valid topological sort (dependency ordered) execution
  // plan. In particular, it is valid for this ordering to contain only a
  // subset of the node indices.
  std::vector<int> execution_plan_;

 execution_plan_.push_back(new_node_index);
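
execution_plan_ is what Invoke() later walks, calling OpInvoke on each node in dependency order. A simplified sketch of that loop (the real Interpreter::Invoke also handles memory planning, delegates and dynamic tensors):

  TfLiteStatus Interpreter::Invoke() {   // simplified sketch
    for (int node_index : execution_plan_) {
      TfLiteNode& node = nodes_and_registration_[node_index].first;
      const TfLiteRegistration& registration =
          nodes_and_registration_[node_index].second;
      if (OpInvoke(registration, &node) != kTfLiteOk) {
        return kTfLiteError;
      }
    }
    return kTfLiteOk;
  }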

6. ParseTensors(buffers, tensors, interpreter->get())

For each tensor index, context_.tensors[tensor_index] is looked up and then filled in, including its allocation type (read-only vs. read-write).

  TfLiteStatus InterpreterBuilder::ParseTensors(
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
      const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
      Interpreter* interpreter) {
    TfLiteStatus status = kTfLiteOk;
    // (The helper lambdas get_readonly_data and get_name are elided here.)

    for (int i = 0; i < tensors->Length(); ++i) {
      const auto* tensor = tensors->Get(i);  // Get(i) again
      std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());

      TfLiteQuantizationParams quantization;
      quantization.scale = 0;
      quantization.zero_point = 0;
      auto* q_params = tensor->quantization();
      if (q_params) { /* fill in scale / zero_point; elided */ }

      TfLiteType type;
      if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
          kTfLiteOk) {
        // error handling elided
      }

      size_t buffer_size = 0;
      const char* buffer_ptr;
      TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
      bool is_variable = tensor->is_variable();

      if (buffer_ptr) {
        if (is_variable) {
          // a variable tensor must not have a constant buffer
          status = kTfLiteError;
        }

        if (interpreter->SetTensorParametersReadOnly(
                i, type, get_name(tensor), dims, quantization, buffer_ptr,
                buffer_size, allocation_) != kTfLiteOk) {
          status = kTfLiteError;
        }
      } else {
        if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor),
                                                      dims, quantization,
                                                      is_variable) != kTfLiteOk) {
          status = kTfLiteError;
        }
      }
    }
    return status;
}

6.1 ConvertTensorType(tensor->type(), &type, error_reporter_)

The tensor type enums defined in the flatbuffer schema and in TFLite are different.
Tensor types in the flatbuffer schema:
enum TensorType {
  TensorType_FLOAT32 = 0,                                                                                            
  TensorType_FLOAT16 = 1,
  TensorType_INT32 = 2,
  TensorType_UINT8 = 3,
  TensorType_INT64 = 4,
  TensorType_STRING = 5,
  TensorType_BOOL = 6,
  TensorType_INT16 = 7,
  TensorType_COMPLEX64 = 8,
  TensorType_MIN = TensorType_FLOAT32,
  TensorType_MAX = TensorType_COMPLEX64
};

Tensor types in TFLite:
typedef enum {
  kTfLiteNoType = 0,
  kTfLiteFloat32 = 1,                                                                                                                  
  kTfLiteInt32 = 2,
  kTfLiteUInt8 = 3,
  kTfLiteInt64 = 4,
  kTfLiteString = 5,
  kTfLiteBool = 6,
  kTfLiteInt16 = 7,
  kTfLiteComplex64 = 8,
} TfLiteType;
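
ConvertTensorType maps one enum onto the other. An abbreviated sketch of its shape (the real function covers every type and reports unsupported ones through the ErrorReporter):

  TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
                                 ErrorReporter* error_reporter) {
    switch (tensor_type) {
      case TensorType_FLOAT32:
        *type = kTfLiteFloat32;
        break;
      case TensorType_INT32:
        *type = kTfLiteInt32;
        break;
      case TensorType_UINT8:
        *type = kTfLiteUInt8;
        break;
      // ... the remaining types are mapped the same way ...
      default:
        error_reporter->Report("Unimplemented data type %d", tensor_type);
        return kTfLiteError;
    }
    return kTfLiteOk;
  }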

6.2 SetTensorParametersReadOnly


  // Set description of inputs/outputs/data/fptrs for node `node_index`.
  // This variant assumes an external buffer has been allocated of size
  // bytes. The lifetime of buffer must be ensured to be greater or equal
  // to Interpreter.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),                                               
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }


  TfLiteStatus Interpreter::SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
      size_t bytes, const Allocation* allocation) {
    TF_LITE_ENSURE(&context_,
                   tensor_index < context_.tensors_size && tensor_index >= 0);
    // For most tensors we know exactly how much memory is necessary so we can
    // ensure the buffer is large enough. However, we need to skip string tensors
    // because their sizes change with the contents of the individual strings.
    if (type != kTfLiteString) {
      size_t required_bytes;
      TF_LITE_ENSURE_OK(&context_,
                        BytesRequired(type, dims, rank, &required_bytes));
      TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
    }
  
    TfLiteTensor& tensor = context_.tensors[tensor_index];
    if (type == tensor.type &&
        EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) {
      // Fast path which does not invalidate the invokable property.
      TfLiteTensorDataFree(&tensor);
      tensor.data.raw = const_cast<char*>(buffer);
      if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims);
      tensor.params = quantization;
      tensor.allocation_type = kTfLiteMmapRo;
      tensor.allocation = allocation; // records the Allocation that owns `buffer` (e.g. the mmapped model); no new memory is allocated for read-only tensors
    } else {
      state_ = kStateUninvokable;
      TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
                        quantization, const_cast<char*>(buffer), bytes,
                        kTfLiteMmapRo, allocation, false, &tensor);
    }
    return kTfLiteOk;
  }  

6.2.1 BytesRequired


  TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
                                          size_t dims_size, size_t* bytes) {
    size_t count = 1;
    for (int k = 0; k < dims_size; k++) count *= dims[k];
    switch (type) {
      case kTfLiteFloat32:
        *bytes = sizeof(float) * count;
        break;
      // ... the other types are handled analogously ...
      default:
        return kTfLiteError;
    }
    return kTfLiteOk;
  }

6.2.2 EqualArrayAndTfLiteIntArray

// Checks whether a `TfLiteIntArray` and an int array have matching elements.
// The caller must guarantee that 'b' has at least 'b_size' elements.
bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,                                                       
                                 const int* b);

  bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,                                                       
                                   const int* b) {
    if (!a) return false;
    if (a->size != b_size) return false;
    for (int i = 0; i < a->size; ++i) {
      if (a->data[i] != b[i]) return false;
    }
    return true;
  }

6.2.3 TfLiteTensorDataFree


// Free data memory of tensor `t`;
void TfLiteTensorDataFree(TfLiteTensor* t);
  void TfLiteTensorDataFree(TfLiteTensor* t) { 
    if (t->allocation_type == kTfLiteDynamic && t->data.raw) {  // only dynamically allocated data is freed
      free(t->data.raw);
    }
    t->data.raw = NULL;
  }

6.2.4 TfLiteTensorReset


  void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
                         TfLiteQuantizationParams quantization, char* buffer,
                         size_t size, TfLiteAllocationType allocation_type,
                         const void* allocation, bool is_variable,
                         TfLiteTensor* tensor) {
    TfLiteTensorFree(tensor);
    tensor->type = type;
    tensor->name = name;
    tensor->dims = dims;
    tensor->params = quantization;
    tensor->data.raw = buffer;
    tensor->bytes = size;
    tensor->allocation_type = allocation_type;
    tensor->allocation = allocation;
    tensor->is_variable = is_variable;
  }

6.3 SetTensorParametersReadWrite


  TfLiteStatus Interpreter::SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
    size_t required_bytes = 0;
    if (type != kTfLiteString) {
      TF_LITE_ENSURE_OK(&context_,
                        BytesRequired(type, dims, rank, &required_bytes));
    }

    TfLiteAllocationType allocation_type = kTfLiteArenaRw;
    if (type == kTfLiteString) {
      if (is_variable) {
        // variable string tensors are not supported; error handling elided
      }
      allocation_type = kTfLiteDynamic;
    } else if (is_variable) {
      allocation_type = kTfLiteArenaRwPersistent;
    }
  
    TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
                      quantization,
                      /*buffer=*/nullptr, required_bytes, allocation_type,
                      nullptr, is_variable, &context_.tensors[tensor_index]);
    return kTfLiteOk;  
  }

Memory allocation for Rw / RwPersistent tensors: the memory backing these tensors is resolved here by the ArenaPlanner, whereas Ro tensors are never given a new allocation (their data points into the model buffer):
  TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) {
    TfLiteTensor& tensor = *graph_info_->tensor(tensor_index);
    if (tensor.allocation_type == kTfLiteArenaRw) {
      if (allocs_[tensor_index].size != 0) {
        TF_LITE_ENSURE_STATUS(arena_.ResolveAlloc(context_, allocs_[tensor_index],
                                                  &tensor.data.raw));                                                                                                                                 
      }
    }
    if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
      TF_LITE_ENSURE_STATUS(persistent_arena_.ResolveAlloc(
          context_, allocs_[tensor_index], &tensor.data.raw));
    }
    return kTfLiteOk;
  }
 
