Organization of the org.tensorflow.lite Java code
and the related native code
Files in the org.tensorflow.lite Java layer
The org.tensorflow.lite Java layer contains the following files; the most important one is Interpreter.java.
1]DataType.java
The type of elements in a TensorFlow Lite {Tensor}; in other words, the basic types supported by flatbuffers.
2]Delegate.java
Wrapper for a native TensorFlow Lite Delegate; it is only an interface.
3]Interpreter.java
Driver class to drive model inference with TensorFlow Lite.
* A {Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations
* are executed for model inference.
* For example, if a model takes only one input and returns only one output:
* try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
* interpreter.run(input, output);
* }
4]NativeInterpreterWrapper.java
An internal wrapper that wraps native interpreter and controls model execution.
Java-layer functions do not call directly into the native layer; NativeInterpreterWrapper sits in between.
5]package-info.java
Contains no real content; reserved for later extension?
6]TensorFlowLite.java
Static utility methods loading the TensorFlowLite runtime.
7]Tensor.java
A typed multi-dimensional array used in TensorFlow Lite.
The Interpreter.java interface
1]Compared with earlier versions, a static class Options has been introduced:
An options class for controlling runtime interpreter behavior.
The runtime interpreter behavior referred to here includes:
Sets the number of threads to be used for ops that support multi-threading.
Sets whether to use NN API (if available) for op execution
Sets whether to allow float16 precision for FP32 calculation when possible
Adds a {Delegate} to be applied during interpreter creation
2]What are the NNAPI and a Delegate?
https://www.tensorflow.org/lite/performance/best_practices
Both the NNAPI and delegates are used for acceleration; they require a corresponding implementation, such as the ArmNN driver for the Android Neural Networks API. From the page above:
What is a TensorFlow Lite delegate?
A TensorFlow Lite delegate is a way to delegate part or all of graph execution to another executor.
Why should I use delegates?
Running inference on compute-heavy machine learning models on mobile devices is resource demanding
due to the devices' limited processing and power.
Instead of relying on the CPU, some devices have hardware accelerators, such as GPU or DSP,
that allow for better performance and higher energy efficiency.
3] The implementation of Interpreter
/**
* Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
* {@link #Options}.
*
* <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
* {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
wrapper = new NativeInterpreterWrapper(byteBuffer, options);
//wrapper here is still a Java object
}
3.1] NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
this.modelByteBuffer = buffer;
long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
init(errorHandle, modelHandle, options);
}
3.1.1] private void init(long errorHandle, long modelHandle, Interpreter.Options options) {
this.errorHandle = errorHandle;
this.modelHandle = modelHandle;
this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
//pre-allocate the Java-side Tensor wrapper arrays for the model's inputs/outputs; allocateTensors() below asks the native interpreter to allocate the actual tensor buffers
this.inputTensors = new Tensor[getInputCount(interpreterHandle)];
this.outputTensors = new Tensor[getOutputCount(interpreterHandle)];
allocateTensors(interpreterHandle, errorHandle);
this.isMemoryAllocated = true;
}
Next we analyze createModelWithBuffer and createInterpreter, which appeared in the Interpreter creation above.
3.1.2 createModelWithBuffer
private static native long createModelWithBuffer(ByteBuffer modelBuffer, long errorHandle);
./java/src/main/native/nativeinterpreterwrapper_jni.cc:303:Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer()
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) {
BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle);
const char* buf = static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
jlong capacity = env->GetDirectBufferCapacity(model_buffer);
auto model = tflite::FlatBufferModel::BuildFromBuffer(//BuildFromBuffer is a public static function, accessed via the class name
buf, static_cast<size_t>(capacity), error_reporter);
return reinterpret_cast<jlong>(model.release());//unique_ptr's release() returns the raw object pointer; model becomes null and no longer manages the object
}
FlatBufferModel::BuildFromBuffer is defined in model.cc
3.1.2.1 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;// manage the FlatBufferModel object with a unique_ptr (model)
Allocation* allocation = new MemoryAllocation(buffer, buffer_size, error_reporter);
model.reset(new FlatBufferModel(allocation, error_reporter));//new a FlatBufferModel object, then let model manage it
if (!model->initialized()) model.reset();
return model;
}
FlatBufferModel::FlatBufferModel(Allocation* allocation, ErrorReporter* error_reporter)
: error_reporter_(ValidateErrorReporter(error_reporter)) {
allocation_ = allocation;
if (!allocation_->valid() || !CheckModelIdentifier()) return;
model_ = ::tflite::GetModel(allocation_->base());
}
tflite::Model *GetModel in lite/schema/schema_generated.h
inline const tflite::Model *GetModel(const void *buf) {
return flatbuffers::GetRoot<tflite::Model>(buf);
}
The implementation of flatbuffers::GetRoot is in flatbuffers/flatbuffers.h
namespace flatbuffers {
template<typename T> const T *GetRoot(const void *buf) {
return GetMutableRoot<T>(const_cast<void *>(buf));
}
}
This finally yields the flatbuffers object described in schema_v3.fbs:
table Model {
// Version of the schema.
version:uint;
// A list of all operator codes used in this model. This is
// kept in order because operators carry an index into this
// vector.
operator_codes:[OperatorCode];
// All the subgraphs of the model. The 0th is assumed to be the main
// model.
subgraphs:[SubGraph];
// A description of the model.
description:string;
// Buffers of the model.
// NOTE: It is required that the first entry in here is always an empty
// buffer. This is so that the default buffer index of zero in Tensor
// will always refer to a valid empty buffer.
buffers:[Buffer];
}
4 createErrorReporter: how are logs printed?
The JNI layer prints logs via throwException.
lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
JNIEnv* env, jclass clazz, jint size) {//conversion from the Java environment to the native environment
BufferErrorReporter* error_reporter = new BufferErrorReporter(env, static_cast<int>(size));
return reinterpret_cast<jlong>(error_reporter);
}
// The constructor of class BufferErrorReporter simply allocates a block of memory to hold logs
BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) {
buffer_ = new char[limit];
start_idx_ = 0;
end_idx_ = limit - 1;
}
3.1.2.2 Allocation
The Allocation here refers to the memory occupied by the model file; it is passed in via the ByteBuffer, so no allocation is needed. The class merely wraps and manages that memory.
class MemoryAllocation : public Allocation {
public:
// Allocates memory with the pointer and the number of bytes of the memory.
// The pointer has to remain alive and unchanged until the destructor is
// called.
MemoryAllocation(const void* ptr, size_t num_bytes,
ErrorReporter* error_reporter);
virtual ~MemoryAllocation();
const void* base() const override;
size_t bytes() const override;
bool valid() const override;
private:
const void* buffer_;
size_t buffer_size_bytes_ = 0;
};
// A memory allocation handle. This could be a mmap or shared memory.
class Allocation {
public:
Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {}
virtual ~Allocation() {}
// Base pointer of this allocation
virtual const void* base() const = 0;
// Size in bytes of the allocation
virtual size_t bytes() const = 0;
// Whether the allocation is valid
virtual bool valid() const = 0;
protected:
ErrorReporter* error_reporter_;
};
// assign the input parameters to the MemoryAllocation member variables
MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
ErrorReporter* error_reporter)
: Allocation(error_reporter) {
buffer_ = ptr;
buffer_size_bytes_ = num_bytes;
}