TFLite: Code Analysis (1): Obtaining the Model File's FlatBuffers Object

Composition of the org.tensorflow.lite Java code

The native-related code

Files contained in the org.tensorflow.lite Java layer

The org.tensorflow.lite Java layer contains the following files; the most important of them is Interpreter.java.

1]DataType.java

The type of elements in a TensorFlow Lite {Tensor}; in other words, the basic types supported by FlatBuffers.

2]Delegate.java

Wrapper for a native TensorFlow Lite Delegate; it is just an interface.


3]Interpreter.java

Driver class to drive model inference with TensorFlow Lite.
A {Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
are executed for model inference.
For example, if a model takes only one input and returns only one output:

  try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
    interpreter.run(input, output);
  }


4]NativeInterpreterWrapper.java

An internal wrapper that wraps native interpreter and controls model execution.

Java-layer methods do not call into the native layer directly; NativeInterpreterWrapper sits in between.


5]package-info.java

Nothing much in it; perhaps reserved for later extension?


6]TensorFlowLite.java

Static utility methods loading the TensorFlowLite runtime.

7]Tensor.java

A typed multi-dimensional array used in TensorFlow Lite.

 

The Interpreter.java interface

 

1] Compared with earlier versions, a static class Options has been introduced:

An options class for controlling runtime interpreter behavior

The runtime interpreter behavior referred to here includes the following (a small configuration sketch follows this list):

Sets the number of threads to be used for ops that support multi-threading.

Sets whether to use NN API (if available) for op execution

Sets whether to allow float16 precision for FP32 calculation when possible

Adds a {Delegate} to be applied during interpreter creation
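
Putting the four options together, here is a minimal configuration sketch. It assumes the Interpreter.Options setters of this TFLite version (setNumThreads, setUseNNAPI, setAllowFp16PrecisionForFp32, addDelegate); the file name model.tflite is hypothetical.

  // Sketch: configure runtime behavior before creating the Interpreter.
  Interpreter.Options options = new Interpreter.Options();
  options.setNumThreads(4);                     // ops that support multi-threading
  options.setUseNNAPI(true);                    // use NN API (if available) for op execution
  options.setAllowFp16PrecisionForFp32(true);   // allow float16 precision for FP32 calculation
  // options.addDelegate(delegate);             // apply a Delegate during interpreter creation
  Interpreter interpreter = new Interpreter(new java.io.File("model.tflite"), options);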

2] What is NNAPI? What is a Delegate?

https://www.tensorflow.org/lite/performance/best_practices

NNAPI and delegates are both for acceleration, and both need a corresponding implementation, for example the ArmNN driver for the Android Neural Networks API. In addition, the best-practices page above explains:

What is a TensorFlow Lite delegate?
A TensorFlow Lite delegate is a way to delegate part or all of graph execution to another executor.

Why should I use delegates?
Running inference on compute-heavy machine learning models on mobile devices is resource demanding 
due to the devices' limited processing and power.

Instead of relying on the CPU, some devices have hardware accelerators, such as GPU or DSP,
that allow for better performance and higher energy efficiency.
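
As a concrete delegate example, a sketch using the GPU delegate (org.tensorflow.lite.gpu.GpuDelegate, shipped as a separate tensorflow-lite-gpu artifact, so its availability depends on the build):

  // Sketch: hand graph execution to the GPU instead of the CPU.
  GpuDelegate gpuDelegate = new GpuDelegate();
  Interpreter.Options options = new Interpreter.Options().addDelegate(gpuDelegate);
  Interpreter interpreter = new Interpreter(modelBuffer, options);  // modelBuffer: a model ByteBuffer, see 3] below
  // ... run inference ...
  interpreter.close();
  gpuDelegate.close();  // the delegate owns native resources and must be closed explicitly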

3] Implementation of Interpreter

  /** 
   * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
   * {@link #Options}.
   *
   * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
   * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
   * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
   */
  public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
    wrapper = new NativeInterpreterWrapper(byteBuffer, options);

    // wrapper here is still just a Java object; the native handles live inside NativeInterpreterWrapper
  } 
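
For reference, the ByteBuffer mentioned in the Javadoc above is typically a MappedByteBuffer obtained by memory-mapping the .tflite file with plain java.nio. A sketch (the path is hypothetical, and the enclosing method must handle IOException):

  // Sketch: memory-map the model file and hand it to the Interpreter.
  try (FileInputStream stream = new FileInputStream("/sdcard/model.tflite");
       FileChannel channel = stream.getChannel()) {
    MappedByteBuffer modelBuffer =
        channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
    Interpreter interpreter = new Interpreter(modelBuffer, new Interpreter.Options());
  }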

  3.1] NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {                                                                               
    this.modelByteBuffer = buffer;
    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
    init(errorHandle, modelHandle, options);
  }
 

  3.1.1] private void init(long errorHandle, long modelHandle, Interpreter.Options options) {
    this.errorHandle = errorHandle;
    this.modelHandle = modelHandle;
    this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);

   // What is this Tensor allocation for? These arrays cache the Java Tensor wrappers for the model's inputs/outputs; entries are created lazily by getInputTensor()/getOutputTensor().
    this.inputTensors = new Tensor[getInputCount(interpreterHandle)];
    this.outputTensors = new Tensor[getOutputCount(interpreterHandle)];

    allocateTensors(interpreterHandle, errorHandle);
    this.isMemoryAllocated = true;
  }
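
Because the Tensor arrays and allocateTensors() are set up during construction, tensor metadata can already be queried from Java before any inference runs. A small sketch (the concrete shape and type values are only examples):

  // Sketch: query the cached input Tensor wrapper exposed through Interpreter.
  Tensor inputTensor = interpreter.getInputTensor(0);
  int[] shape = inputTensor.shape();        // e.g. {1, 224, 224, 3}
  DataType type = inputTensor.dataType();   // e.g. DataType.FLOAT32
  int bytes = inputTensor.numBytes();       // size of the underlying data buffer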
The following analyzes createModelWithBuffer and createInterpreter, which appear in the Interpreter-creation process above.

3.1.2] createModelWithBuffer

private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle);

./java/src/main/native/nativeinterpreterwrapper_jni.cc:303:Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer()

  JNIEXPORT jlong JNICALL
  Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
      JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) {
    BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle);
    const char* buf = static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
    jlong capacity = env->GetDirectBufferCapacity(model_buffer);
    auto model = tflite::FlatBufferModel::BuildFromBuffer(  // BuildFromBuffer is a public static function, accessed via the class name
        buf, static_cast<size_t>(capacity), error_reporter);
    return reinterpret_cast<jlong>(model.release());  // unique_ptr::release() returns the raw pointer and gives up ownership; the pointer is handed back to Java as a jlong handle
  }

  FlatBufferModel::BuildFromBuffer is in model.cc:

  3.1.2.1] std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
      const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
    error_reporter = ValidateErrorReporter(error_reporter);
  
    std::unique_ptr<FlatBufferModel> model;  // the FlatBufferModel object will be managed by the unique_ptr 'model'
    Allocation* allocation = new MemoryAllocation(buffer, buffer_size, error_reporter);
    model.reset(new FlatBufferModel(allocation, error_reporter));  // new a FlatBufferModel object, then let 'model' manage it
    if (!model->initialized()) model.reset();
    return model;
  }

  FlatBufferModel::FlatBufferModel(Allocation* allocation, ErrorReporter* error_reporter)
      : error_reporter_(ValidateErrorReporter(error_reporter)) {
    allocation_ = allocation;
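    // CheckModelIdentifier() verifies the FlatBuffers file identifier ("TFL3") before the buffer is treated as a Model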
    if (!allocation_->valid() || !CheckModelIdentifier()) return;
  
    model_ = ::tflite::GetModel(allocation_->base());
  }

tflite::Model *GetModel in lite/schema/schema_generated.h:

inline const tflite::Model *GetModel(const void *buf) {                                                                                                            
  return flatbuffers::GetRoot<tflite::Model>(buf);
}


flatbuffers::GetRoot is implemented in flatbuffers/flatbuffers.h:

namespace flatbuffers {

template<typename T> const T *GetRoot(const void *buf) {
  return GetMutableRoot<T>(const_cast<void *>(buf));
}

}

This is how we finally obtain the FlatBuffers object described by schema_v3.fbs:

table Model {
  // Version of the schema.
  version:uint;

  // A list of all operator codes used in this model. This is
  // kept in order because operators carry an index into this
  // vector.
  operator_codes:[OperatorCode];

  // All the subgraphs of the model. The 0th is assumed to be the main
  // model.
  subgraphs:[SubGraph];

  // A description of the model.
  description:string;

  // Buffers of the model.
  // NOTE: It is required that the first entry in here is always an empty
  // buffer. This is so that the default buffer index of zero in Tensor
  // will always refer to a valid empty buffer.                              
  buffers:[Buffer];

}

4] createErrorReporter: how is the log printed?


The JNI layer reports the log to Java via throwException.
lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
    JNIEnv* env, jclass clazz, jint size) {  // crossing from the Java environment into the native environment

  BufferErrorReporter* error_reporter = new BufferErrorReporter(env, static_cast<int>(size));

  return reinterpret_cast<jlong>(error_reporter);
}
// The constructor of class BufferErrorReporter simply allocates a block of memory to hold the log
BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) {
  buffer_ = new char[limit];
  start_idx_ = 0;
  end_idx_ = limit - 1;
}
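
On the Java side, the text collected by BufferErrorReporter usually surfaces as the message of an exception: when a native call fails, the JNI layer passes the buffered text to throwException. A hedged sketch of what application code sees (the exact exception type depends on the call site):

  try {
    interpreter.run(input, output);
  } catch (IllegalArgumentException e) {
    // The message text was accumulated in BufferErrorReporter and attached via throwException in the JNI layer.
    System.err.println("TFLite error: " + e.getMessage());
  }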

 

3.2.2] Allocation

The Allocation here refers to the memory used by the model file; it comes from the ByteBuffer that was passed in, so nothing new needs to be allocated and the class is presumably just used to manage (wrap) that memory.
class MemoryAllocation : public Allocation {
 public:
  // Allocates memory with the pointer and the number of bytes of the memory.
  // The pointer has to remain alive and unchanged until the destructor is
  // called.
  MemoryAllocation(const void* ptr, size_t num_bytes,
                   ErrorReporter* error_reporter);
  virtual ~MemoryAllocation();
  const void* base() const override;
  size_t bytes() const override;
  bool valid() const override;

 private:
  const void* buffer_;
  size_t buffer_size_bytes_ = 0;
};

// A memory allocation handle. This could be a mmap or shared memory.
class Allocation {
 public:
  Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {}
  virtual ~Allocation() {}

  // Base pointer of this allocation
  virtual const void* base() const = 0;
  // Size in bytes of the allocation
  virtual size_t bytes() const = 0;
  // Whether the allocation is valid
  virtual bool valid() const = 0;

 protected:
  ErrorReporter* error_reporter_;
};

// Copy the constructor arguments into the MemoryAllocation member variables
MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
                                   ErrorReporter* error_reporter)
    : Allocation(error_reporter) {
  buffer_ = ptr;
  buffer_size_bytes_ = num_bytes;
}
