TFLite实现的log:
lite/error_reporter.h
namespace tflite {
// Abstract error-reporting interface used throughout TFLite.
class ErrorReporter {
public:
virtual ~ErrorReporter();
// The one overload subclasses must implement; returns the number of
// characters written (vfprintf-style).
virtual int Report(const char* format, va_list args) = 0;
// Convenience variadic wrappers; presumably forward to the va_list
// overload above -- implementation not shown here.
int Report(const char* format, ...);
int ReportError(void*, const char* format, ...);
};
// An error reporter that simply writes the message to stderr.
// Note: struct and class are functionally identical in C++; only the
// default access differs (struct: public, class: private).
struct StderrReporter : public ErrorReporter {
int Report(const char* format, va_list args) override;
};
// Return the default error reporter (output to stderr).
ErrorReporter* DefaultErrorReporter();
}
基类ErrorReporter向外提供的接口是:Report / ReportError
子类要实现的接口是 virtual int Report(const char* format, va_list args) = 0;
这里实现了一个默认子类StderrReporter,实现了Report: 打印到stderr:
stderr是什么?
// Writes the formatted message to stderr, then terminates the line.
// Returns vfprintf's result: the number of characters printed, or a
// negative value on output error (the trailing '\n' is not counted).
int StderrReporter::Report(const char* format, va_list args) {
  const int chars_printed = vfprintf(stderr, format, args);
  fputc('\n', stderr);
  return chars_printed;
}
DefaultErrorReporter函数只是实例化了StderrReporter
// Returns a process-wide StderrReporter singleton.
ErrorReporter* DefaultErrorReporter() {
// Heap-allocated and never freed -- presumably deliberate, to avoid
// static-destruction-order problems at process exit. TODO confirm.
static StderrReporter* error_reporter = new StderrReporter;
return error_reporter;
}
TFLite使用android系统输出log,可以理解为把TFLite log输出到Android系统java空间,更精确的描述是把TFLite的错误信息输出到java空间
lite/java/src/main/native/exception_jni.h
// A BufferErrorReporter accumulates error messages in an in-memory
// buffer so they can later be surfaced to Java (e.g. as the message of
// a thrown exception).
class BufferErrorReporter : public tflite::ErrorReporter {
public:
BufferErrorReporter(JNIEnv* env, int limit);
virtual ~BufferErrorReporter();
// Appends the formatted message at the current write position.
int Report(const char* format, va_list args) override;
// Returns the buffered error text (contents are undefined until the
// first Report call -- the buffer is not zero-initialized).
const char* CachedErrorMessage();
private:
char* buffer_;       // heap buffer of 'limit' chars, owned by this object
int start_idx_ = 0;  // next write position within buffer_
int end_idx_ = 0;    // one-past-last usable index (set to limit - 1 by ctor)
};
// Allocates the message buffer whose size is chosen by the Java caller.
BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) {
buffer_ = new char[limit];// allocate 'limit' chars, size passed in from Java
// NOTE(review): this check is dead code -- plain operator new throws
// std::bad_alloc on failure rather than returning nullptr, so this
// branch can never fire. It would only work with 'new (std::nothrow)'.
if (!buffer_) {
throwException(env, kNullPointerException,
"Malloc of BufferErrorReporter to hold %d char failed.",
limit);
return;
}
start_idx_ = 0;// initialize the write cursor used to manage the buffer
end_idx_ = limit - 1;
}
// Accessor for the buffered error text. Note the contents are undefined
// until Report() has written at least once (the buffer is not zeroed).
const char* BufferErrorReporter::CachedErrorMessage() { return buffer_; }
// printf-style: formats 'args' per 'format' into the buffer allocated by
// the constructor, advancing start_idx_ past the text just written.
// Returns vsnprintf's result (0 when the buffer is already full).
int BufferErrorReporter::Report(const char* format, va_list args) {
  int size = 0;
  if (start_idx_ < end_idx_) {
    size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args);
  }
  // vsnprintf returns a negative value on an encoding/output error.
  // The original code added it unconditionally, which would move the
  // write cursor backwards and let a later call overwrite earlier text.
  // On truncation size may exceed the remaining space; start_idx_ then
  // passes end_idx_ and subsequent writes are skipped, which is safe.
  if (size > 0) {
    start_idx_ += size;
  }
  return size;
}
怎样把BufferErrorReporter里的内容输出呢?
使用类外的函数 throwException
// Formats the message into a local buffer, then raises a Java exception
// of class 'clazz' carrying that message through the JNI environment.
// In practice the message is typically BufferErrorReporter's cached text.
void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
  va_list args;
  va_start(args, fmt);
  // Fixed-size stack buffer instead of the original unchecked malloc():
  // if malloc failed, the old code passed nullptr to vsnprintf, which is
  // undefined behavior. Same 512-byte capacity as before.
  const size_t max_msg_len = 512;
  char message[max_msg_len];
  if (vsnprintf(message, max_msg_len, fmt, args) >= 0) {
    env->ThrowNew(env->FindClass(clazz), message);
  } else {
    // Formatting failed; still throw, just with an empty message.
    env->ThrowNew(env->FindClass(clazz), "");
  }
  va_end(args);
}
// For example (excerpt from the createInterpreter JNI entry point):
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter:
// On failure, surface the errors accumulated in the BufferErrorReporter
// to Java as an IllegalArgumentException message.
if (status != kTfLiteOk) {
throwException(env, kIllegalArgumentException,
"Cannot create interpreter: %s",
error_reporter->CachedErrorMessage());
}
可以看出TFLite log(android)输出log的方法是把log先保存到BufferErrorReporter, 然后通过
jni函数throwException到java环境把log输出。
现在反过来,何时创建了BufferErrorReporter?
lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
// Wraps a native interpreter: first creates the native error reporter,
// then hands its handle to the native model/interpreter factories so
// they can report errors through it.
NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) {
modelByteBuffer = mappedByteBuffer;
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle);
}
// createErrorReporter -> native side: allocates a BufferErrorReporter
// and returns its address to Java as an opaque jlong handle.
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
JNIEnv* env, jclass clazz, jint size) {
BufferErrorReporter* error_reporter =
new BufferErrorReporter(env, static_cast<int>(size));
// Raw pointer smuggled to Java as a jlong; presumably released later by
// a matching native delete call -- TODO confirm against the Java side.
return reinterpret_cast<jlong>(error_reporter);
}
//java层会使用从native返回的 errorHandle 创建Model/ Interpreter,
//也就说Model/ Interpreter使用上面的机制向外输出 log.
下面看下TFLite代码的log输入
1. model.cc
// Excerpt of FlatBufferModel showing only the member relevant to logging.
// (Added the missing ';' after the closing brace -- the pasted snippet
// would not compile without it.)
class FlatBufferModel {
 private:  // member variable
  ErrorReporter* error_reporter_;
};
//把jni创建的BufferErrorReporter传递到TFLite
createModelWithBuffer(modelByteBuffer, errorHandle)
--> Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer
--> tflite::FlatBufferModel::BuildFromBuffer(
buf, static_cast<size_t>(capacity), error_reporter)
--> FlatBufferModel(buffer, buffer_size, error_reporter)
// Stores the reporter handed in from JNI (a BufferErrorReporter here),
// falling back to the stderr-based DefaultErrorReporter() when null.
FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()) {
allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
if (!allocation_->valid()) return;  // leave model_ unset on a bad allocation
model_ = ::tflite::GetModel(allocation_->base());
}
使用方法:
// Example use of error_reporter_: verifies the flatbuffer identifier and
// reports a formatted error (four chars found vs. expected) on mismatch.
bool FlatBufferModel::CheckModelIdentifier() const {
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
error_reporter_->Report(
"Model provided has model identifier '%c%c%c%c', should be '%s'\n",
ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
return false;
}
return true;
}
2. Interpreter
// Excerpt of Interpreter showing only the members relevant to logging.
// (Added the missing ';' after the closing brace -- the pasted snippet
// would not compile without it.)
class Interpreter {
  // The error reporter delegate that tflite will forward queries errors to.
  ErrorReporter* error_reporter_;
  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  TfLiteContext context_;
};
//interpreter: 可以直接使用成员变量 error_reporter_,
3. TfLiteContext
TfLiteContext的ReportError是通过Interpreter的ReportError实现的;
TfLiteContext/Interpreter相互是数据成员,
// C-side context handed to op plugins; carries the error callback.
// (The pasted snippet dropped the typedef name and semicolon after the
// closing brace -- restored here as '} TfLiteContext;'.)
typedef struct TfLiteContext {
  // Request that an error be reported with format string msg.
  void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
  // Opaque full context ptr (an opaque c++ data structure); holds the
  // owning Interpreter's 'this' so the static callback can recover it.
  void* impl_;
} TfLiteContext;
// Falls back to DefaultErrorReporter() (stderr) when no reporter is
// supplied, then wires the C-style context callback back to this object.
Interpreter::Interpreter(ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()) {
context_.impl_ = static_cast<void*>(this);  // lets the static callback recover 'this'
context_.ReportError = ReportError;  // static member, so C code can call it
}
//再看TFLiteContext的 ReportError,它的实现是Interpreter的ReportError
context_.ReportError = ReportError;
context_.impl_ = static_cast<void*>(this);
// Static C-compatible trampoline installed into TfLiteContext::ReportError:
// packages the varargs into a va_list and forwards to the instance method.
void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) {
va_list args;
va_start(args, format);
auto* f = static_cast<Interpreter*>(context->impl_);
// Note here that context->impl_ is recovering the this pointer for an
// instance of Interpreter to call into the member function ReportErrorImpl
// (this function is static).
f->ReportErrorImpl(format, args);
va_end(args);
}
// Forwards to whichever ErrorReporter this interpreter was built with
// (the BufferErrorReporter from JNI, or the stderr default).
void Interpreter::ReportErrorImpl(const char* format, va_list args) {
error_reporter_->Report(format, args);
}
知道TFLite errorReporter的原理,就可以debug,把想打的信息打印出来;
原来以为可以通过throwException的方式,但是app直接crash了
env->ThrowNew(env->FindClass(clazz), message)本来就是抛出异常的,!!
尝试使用DefaultErrorReporter
// (Repeated from above for reference.) Process-wide StderrReporter
// singleton; heap-allocated and intentionally never freed -- TODO confirm.
ErrorReporter* DefaultErrorReporter() {
static StderrReporter* error_reporter = new StderrReporter;
return error_reporter;
}
// (Repeated from above for reference.) Prints to stderr plus a newline.
int StderrReporter::Report(const char* format, va_list args) {
const int result = vfprintf(stderr, format, args);  // formatted message
fputc('\n', stderr);  // always terminate the line
return result;  // chars printed by vfprintf (negative on error)
}
06-26 04:57:12.119 5896 5896 E tflite : xxx test
06-26 04:57:12.121 5896 5896 E tflite : xxx test in jni
结论
如果想debug TFLite的代码、输出log,可以使用DefaultErrorReporter;
jni: tflite::DefaultErrorReporter()->Report("xxx test in jni\n");
tflite: ErrorReporter* testLog = DefaultErrorReporter();
testLog->Report("wenshuai test\n");
打印tflite文件的tensor 和node
// Minimal driver: loads a .tflite model, dumps interpreter state before
// and after one Invoke(). Usage: minimal <model.tflite>
int main(int argc, char* argv[]) {
if(argc != 2) {
fprintf(stderr, "minimal <tflite model>\n");
return 1;
}
const char* filename = argv[1];
// Load model
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(filename);
TFLITE_MINIMAL_CHECK(model != nullptr);
// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;// original note asked whether this needs populating; presumably its default ctor registers the builtin ops -- TODO confirm
InterpreterBuilder builder(*model.get(), resolver);
std::unique_ptr<Interpreter> interpreter;
builder(&interpreter);
TFLITE_MINIMAL_CHECK(interpreter != nullptr);
// Allocate tensor buffers.
TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
printf("=== Pre-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(interpreter.get()); // .get() yields the raw Interpreter* expected by PrintInterpreterState
// Fill input buffers
// TODO(user): Insert code to fill input tensors
TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
tflite::DynamicBuffer buf;
// NOTE(review): 'sentence' is not declared anywhere in this snippet; it
// must be defined (e.g. a std::string) before this compiles -- TODO confirm.
buf.AddString(sentence.data(), sentence.length());
buf.WriteToTensor(input);
// Run inference
TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
printf("\n\n=== Post-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(interpreter.get());
// Read output buffers
// TODO(user): Insert getting data out code.
// NOTE(review): these two outputs are fetched but never read here.
TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);
return 0;
}
代码路径 lite/optional_debug_tools.cc
// Prints a dump of what tensors and what nodes are in the interpreter.
// NOTE(review): removed the stray '>>' markers pasted before the two
// for-loops in the original notes (they are not valid C++), and cast
// the size_t tensor/node counts so the int loop indices compare without
// a signed/unsigned mismatch.
void PrintInterpreterState(Interpreter* interpreter) {
  printf("Interpreter has %zu tensors and %zu nodes\n",
         interpreter->tensors_size(), interpreter->nodes_size());
  printf("Inputs:");
  PrintIntVector(interpreter->inputs());
  printf("Outputs:");
  PrintIntVector(interpreter->outputs());
  printf("\n");
  // One line per tensor: index, name, type, allocation kind, size, dims.
  for (int tensor_index = 0;
       tensor_index < static_cast<int>(interpreter->tensors_size());
       tensor_index++) {
    TfLiteTensor* tensor = interpreter->tensor(tensor_index);
    printf("Tensor %3d %-20s %10s %15s %10zu bytes (%4.1f MB) ", tensor_index,
           tensor->name, TensorTypeName(tensor->type),
           AllocTypeName(tensor->allocation_type), tensor->bytes,
           (static_cast<float>(tensor->bytes) / (1 << 20)));
    PrintTfLiteIntVector(tensor->dims);
  }
  printf("\n");
  // One line per node: custom-op name or builtin op code, then the
  // node's input and output tensor indices.
  for (int node_index = 0;
       node_index < static_cast<int>(interpreter->nodes_size());
       node_index++) {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg =
        interpreter->node_and_registration(node_index);
    const TfLiteNode& node = node_and_reg->first;
    const TfLiteRegistration& reg = node_and_reg->second;
    if (reg.custom_name != nullptr) {
      printf("Node %3d Operator Custom Name %s\n", node_index, reg.custom_name);
    } else {
      printf("Node %3d Operator Builtin Code %3d\n", node_index,
             reg.builtin_code);
    }
    printf(" Inputs:");
    PrintTfLiteIntVector(node.inputs);
    printf(" Outputs:");
    PrintTfLiteIntVector(node.outputs);
  }
}
=== Pre-invoke Interpreter State ===
Interpreter has 90 tensors and 31 nodes
Inputs: 88
Outputs: 87
Tensor 0 MobilenetV1/Logits/AvgPool_1a/AvgPool kTfLiteUInt8 kTfLiteArenaRw 1024 bytes ( 0.0 MB) 1 1 1 1024
Tensor 1 MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd kTfLiteUInt8 kTfLiteArenaRw 1001 bytes ( 0.0 MB) 1 1 1 1001
Tensor 2 MobilenetV1/Logits/Conv2d_1c_1x1/Conv2D_bias kTfLiteInt32 kTfLiteMmapRo 4004 bytes ( 0.0 MB) 1001
Tensor 3 MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/FakeQuantWithMinMaxVars kTfLiteUInt8 kTfLiteMmapRo 1025024 bytes ( 1.0 MB) 1001 1 1 1024
Tensor 4 MobilenetV1/MobilenetV1/Conv2d_0/Conv2D_Fold_bias kTfLiteInt32 kTfLiteMmapRo 128 bytes ( 0.0 MB) 32
Tensor 5 MobilenetV1/MobilenetV1/Conv2d_0/Relu6 kTfLiteUInt8 kTfLiteArenaRw 401408 bytes ( 0.4 MB) 1 112 112 32