TFLite: 代码分析(4): android log输出

TFLite实现的log:


lite/error_reporter.h

namespace tflite {

// Abstract sink for TFLite diagnostics. Callers use the public Report() /
// ReportError() entry points; subclasses implement the pure-virtual va_list
// overload of Report() to decide where the formatted message goes.
class ErrorReporter {
 public:
  virtual ~ErrorReporter();
  // Subclass hook: receives a printf-style format and its already-started
  // argument list.
  virtual int Report(const char* format, va_list args) = 0;
  int Report(const char* format, ...);
  int ReportError(void*, const char* format, ...);
};

// An error reporter that simply writes the message to stderr.
// (Declared with `struct` rather than `class`: the two are equivalent except
// that members and base classes default to public instead of private.)

struct StderrReporter : public ErrorReporter {
  int Report(const char* format, va_list args) override;
};

// Return the default error reporter (output to stderr).
ErrorReporter* DefaultErrorReporter();

}

基类ErrorReporter向外提供的接口是:Report / ReportError
子类要实现的接口是 virtual int Report(const char* format, va_list args) = 0;

这里实现了一个默认子类StderrReporter,实现了Report: 打印到stderr:
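stderr是什么?stderr是C标准库中的标准错误输出流(一个FILE*),vfprintf(stderr, ...)把格式化后的内容写到这个流里。
需要注意:在Android上,应用进程的stdout/stderr默认被重定向到/dev/null,不会出现在logcat中。
据TFLite源码,StderrReporter::Report在定义了__ANDROID__时会改用__android_log_vprint(tag为"tflite")输出到logcat;下面贴出的只是非Android分支。这也是后文日志里出现"E tflite"的原因。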

// Writes the formatted message to stderr and terminates it with a newline.
// The return value is whatever vfprintf reported for the message itself;
// the trailing newline is not counted.
int StderrReporter::Report(const char* format, va_list args) {
  const int chars_written = vfprintf(stderr, format, args);
  putc('\n', stderr);
  return chars_written;
}

DefaultErrorReporter函数只是实例化了一个StderrReporter(函数内的static指针,只在第一次调用时创建)并返回它:
// Returns the process-wide fallback reporter. It is created lazily on the
// first call and never deleted here, so every caller receives the same
// pointer.
ErrorReporter* DefaultErrorReporter() {
  static StderrReporter* const instance = new StderrReporter;
  return instance;
}

TFLite使用android系统输出log,可以理解为把TFLite log输出到Android系统java空间,更精确的描述是把TFLite的错误信息输出到java空间


lite/java/src/main/native/exception_jni.h

 

//是个BufferErrorReporter, 就是把error保存在buffer
class BufferErrorReporter : public tflite::ErrorReporter {
 public:
  BufferErrorReporter(JNIEnv* env, int limit);
  virtual ~BufferErrorReporter();
  int Report(const char* format, va_list args) override;
  const char* CachedErrorMessage();

 private:
  char* buffer_;
  int start_idx_ = 0;
  int end_idx_ = 0;
};

// Allocates the message buffer whose size `limit` is requested by the Java
// side, and initializes it to an empty string.
BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) {
  // new[] with a negative length throws, and a zero-length buffer cannot even
  // hold the terminating NUL, so always allocate at least one byte.
  const int capacity = limit > 0 ? limit : 1;
  buffer_ = new char[capacity];
  if (!buffer_) {
    // Defensive only: a plain new[] signals failure by throwing rather than
    // by returning nullptr.
    throwException(env, kNullPointerException,
                   "Malloc of BufferErrorReporter to hold %d char failed.",
                   limit);
    return;
  }
  // new char[] leaves the contents indeterminate; start with an empty string
  // so CachedErrorMessage() is valid even if Report() is never called.
  buffer_[0] = '\0';
  start_idx_ = 0;           // next write position
  end_idx_ = capacity - 1;  // Report() never writes at or beyond this index
}

// Returns the messages accumulated by Report(). The pointer refers to this
// reporter's own buffer (not a copy), so callers must not free it.
const char* BufferErrorReporter::CachedErrorMessage() {
  const char* const cached = buffer_;
  return cached;
}

//类似pritf, 把args以format的格式保存到上面分配的memory
int BufferErrorReporter::Report(const char* format, va_list args) {
  int size = 0;
  if (start_idx_ < end_idx_) {
    size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args);
  }
  start_idx_ += size;
  return size;
}

怎样把BufferErrorReporter里的内容输出呢?
使用类外的函数 throwException
// Formats a printf-style message and raises it as a Java exception of class
// `clazz` through JNI (`clazz` is presumably a JNI class name such as
// "java/lang/IllegalArgumentException" — TODO confirm the constants).
// Callers pass BufferErrorReporter::CachedErrorMessage() as an argument to
// surface the collected TFLite errors in Java.
//
// This raises a real Java exception: if the Java caller does not catch it,
// the app crashes, so it is not suitable for plain debug logging. Messages
// longer than max_msg_len - 1 bytes are truncated.
void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
  const auto exception_class = env->FindClass(clazz);
  if (exception_class == nullptr) {
    // FindClass has already left an exception pending; handing a null class
    // to ThrowNew would be invalid.
    return;
  }
  const size_t max_msg_len = 512;
  // A fixed-size stack buffer: no allocation that can fail, nothing to free.
  char message[max_msg_len];
  va_list args;
  va_start(args, fmt);
  const int written = vsnprintf(message, max_msg_len, fmt, args);
  va_end(args);
  env->ThrowNew(exception_class, written >= 0 ? message : "");
}

// For example (excerpt):
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter: 
  if (status != kTfLiteOk) {
    // Embed everything the BufferErrorReporter has accumulated so far into
    // the Java exception message.
    throwException(env, kIllegalArgumentException,
                   "Cannot create interpreter: %s",
                   error_reporter->CachedErrorMessage());
  }

可以看出TFLite log(android)输出log的方法是把log先保存到BufferErrorReporter, 然后通过
jni函数throwException到java环境把log输出。

现在反过来,何时创建了BufferErrorReporter?


lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
  // Creates the native error reporter first and passes its handle to both the
  // model and interpreter creation calls, so native errors raised while
  // loading are collected in that reporter. The order matters: each native
  // call uses a handle produced by an earlier one.
  NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) {
    modelByteBuffer = mappedByteBuffer;
    errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
    modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
    interpreterHandle = createInterpreter(modelHandle, errorHandle);
  }

// Native side of NativeInterpreterWrapper.createErrorReporter(int): allocates
// a BufferErrorReporter with a `size`-char buffer and returns its address to
// Java as an opaque long handle. Releasing it happens elsewhere (not shown).
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
    JNIEnv* env, jclass clazz, jint size) {
  const int limit = static_cast<int>(size);
  auto* const reporter = new BufferErrorReporter(env, limit);
  return reinterpret_cast<jlong>(reporter);
}

//java层会使用从native返回的 errorHandle 创建Model/ Interpreter,
//也就说Model/ Interpreter使用上面的机制向外输出 log.

下面看下TFLite代码内部是如何把log写入(输入)ErrorReporter的:


1. model.cc

class FlatBufferModel {
    private: //成员变量
      ErrorReporter* error_reporter_;
}

//把jni创建的BufferErrorReporter传递到TFLite
createModelWithBuffer(modelByteBuffer, errorHandle)
    --> Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer
    --> tflite::FlatBufferModel::BuildFromBuffer(
      buf, static_cast<size_t>(capacity), error_reporter)
    --> FlatBufferModel(buffer, buffer_size, error_reporter)

// Wraps the in-memory flatbuffer [ptr, ptr + num_bytes) as a model. A null
// `error_reporter` falls back to DefaultErrorReporter(). If the allocation is
// not valid, the constructor returns early without assigning model_.
FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
                                 ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()) {
  // Hand the allocation the resolved reporter, not the raw argument, so it
  // never receives nullptr when the caller relied on the default.
  allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter_);
  if (!allocation_->valid()) return;

  model_ = ::tflite::GetModel(allocation_->base());
}

使用方法:
bool FlatBufferModel::CheckModelIdentifier() const {
  if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
    const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
    error_reporter_->Report(
        "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
        ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
    return false;
  }
  return true;
}

2. Interpreter


class Interpreter {
  // The error reporter delegate that tflite will forward queries errors to.
  ErrorReporter* error_reporter_;

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  TfLiteContext context_;
}
//interpreter: 可以直接使用成员变量 error_reporter_,

3. TfLiteContext


TfLiteContext的ReportError是通过Interpreter的ReportError实现的;
Interpreter持有TfLiteContext类型的数据成员context_,而TfLiteContext又通过impl_指针反向指向所属的Interpreter。

// C-compatible context used to communicate with the C plugin interface
// (excerpt: only the fields relevant to error reporting are shown).
typedef struct TfLiteContext {
  // Request that an error be reported with format string msg.
  void (*ReportError)(struct TfLiteContext*, const char* msg, ...);

  // opaque full context ptr (an opaque c++ data structure)
  void* impl_;
} TfLiteContext;

// Builds an interpreter that reports errors to `error_reporter`, falling back
// to DefaultErrorReporter() when it is null. It also links context_ back to
// this instance, so ReportError calls made through the context reach it.
Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter != nullptr ? error_reporter
                                                : DefaultErrorReporter()) {
  context_.ReportError = ReportError;
  context_.impl_ = this;
}

// Back to TfLiteContext's ReportError: it is set to Interpreter's static
// ReportError, and impl_ stores this Interpreter so that the callback can
// find the instance it belongs to.
context_.ReportError = ReportError;
context_.impl_ = static_cast<void*>(this);

// C-callable trampoline stored in TfLiteContext::ReportError. Because it is
// static, it recovers the owning Interpreter from context->impl_ and then
// forwards the variadic arguments to that instance's ReportErrorImpl().
void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) {
  Interpreter* const owner = static_cast<Interpreter*>(context->impl_);
  va_list arg_list;
  va_start(arg_list, format);
  owner->ReportErrorImpl(format, arg_list);
  va_end(arg_list);
}

void Interpreter::ReportErrorImpl(const char* format, va_list args) {
  error_reporter_->Report(format, args);
}

知道TFLite errorReporter的原理,就可以debug,把想打的信息打印出来;


原来以为可以通过throwException的方式打印调试信息,但是app直接crash了:
env->ThrowNew(env->FindClass(clazz), message)本来就是抛出Java异常的,Java侧如果没有捕获这个异常,app就会崩溃!!

尝试使用DefaultErrorReporter
ErrorReporter* DefaultErrorReporter() {
  // One shared StderrReporter, built on the first call. It is never deleted,
  // so the returned pointer stays valid for the rest of the program.
  static StderrReporter* const shared_reporter = new StderrReporter;
  return shared_reporter;
}

int StderrReporter::Report(const char* format, va_list args) {
  // Print the message, then a newline. Only vfprintf's count (negative on an
  // output error) is returned; the newline is not included in it.
  const int printed = vfprintf(stderr, format, args);
  putc('\n', stderr);
  return printed;
}

06-26 04:57:12.119  5896  5896 E tflite  : xxx test
06-26 04:57:12.121  5896  5896 E tflite  : xxx test in jni
 

结论

如果想debug TFLite的代码、输出log,可以使用DefaultErrorReporter;

jni:     tflite::DefaultErrorReporter()->Report("xxx test in jni\n");

tflite:  ErrorReporter* testLog = DefaultErrorReporter();
          testLog->Report("wenshuai test\n");

打印tflite文件的tensor 和node

int main(int argc, char* argv[]) {
  if(argc != 2) {                                                                   
    fprintf(stderr, "minimal <tflite model>\n");
    return 1;
  }
  const char* filename = argv[1];

  // Load model
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile(filename);
  TFLITE_MINIMAL_CHECK(model != nullptr);

  // Build the interpreter
  tflite::ops::builtin::BuiltinOpResolver resolver;//需要赋值吗?
  InterpreterBuilder builder(*model.get(), resolver);
  std::unique_ptr<Interpreter> interpreter;
  builder(&interpreter);
  TFLITE_MINIMAL_CHECK(interpreter != nullptr);

  // Allocate tensor buffers.
  TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
  printf("=== Pre-invoke Interpreter State ===\n");
  tflite::PrintInterpreterState(interpreter.get()); //interpreter.get()返回的是指向interpreter的指针

  // Fill input buffers
  // TODO(user): Insert code to fill input tensors

      TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
     tflite::DynamicBuffer buf;
      buf.AddString(sentence.data(), sentence.length());
      buf.WriteToTensor(input);  

  // Run inference
  TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
  printf("\n\n=== Post-invoke Interpreter State ===\n");
  tflite::PrintInterpreterState(interpreter.get());

  // Read output buffers
  // TODO(user): Insert getting data out code.

       TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
      TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);

  return 0;
}

代码路径 lite/optional_debug_tools.cc

  // Prints a dump of what tensors and what nodes are in the interpreter.
  void PrintInterpreterState(Interpreter* interpreter) {
    printf("Interpreter has %zu tensors and %zu nodes\n",
           interpreter->tensors_size(), interpreter->nodes_size());
    printf("Inputs:");
    PrintIntVector(interpreter->inputs());
    printf("Outputs:");
    PrintIntVector(interpreter->outputs());
    printf("\n");
>>  for (int tensor_index = 0; tensor_index < interpreter->tensors_size();
         tensor_index++) {
      TfLiteTensor* tensor = interpreter->tensor(tensor_index);
      printf("Tensor %3d %-20s %10s %15s %10zu bytes (%4.1f MB) ", tensor_index,
             tensor->name, TensorTypeName(tensor->type),
             AllocTypeName(tensor->allocation_type), tensor->bytes,
             (static_cast<float>(tensor->bytes) / (1 << 20)));
      PrintTfLiteIntVector(tensor->dims);
    }
    printf("\n");
>>  for (int node_index = 0; node_index < interpreter->nodes_size();
         node_index++) {
      const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg =
          interpreter->node_and_registration(node_index);
      const TfLiteNode& node = node_and_reg->first;
      const TfLiteRegistration& reg = node_and_reg->second;
      if (reg.custom_name != nullptr) {
        printf("Node %3d Operator Custom Name %s\n", node_index, reg.custom_name);
      } else {
        printf("Node %3d Operator Builtin Code %3d\n", node_index,
               reg.builtin_code);
      }
      printf("  Inputs:");
      PrintTfLiteIntVector(node.inputs);
      printf("  Outputs:");
      PrintTfLiteIntVector(node.outputs);
    }
  }

=== Pre-invoke Interpreter State ===
Interpreter has 90 tensors and 31 nodes
Inputs: 88
Outputs: 87

Tensor   0 MobilenetV1/Logits/AvgPool_1a/AvgPool kTfLiteUInt8  kTfLiteArenaRw       1024 bytes ( 0.0 MB)  1 1 1 1024
Tensor   1 MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd kTfLiteUInt8  kTfLiteArenaRw       1001 bytes ( 0.0 MB)  1 1 1 1001
Tensor   2 MobilenetV1/Logits/Conv2d_1c_1x1/Conv2D_bias kTfLiteInt32   kTfLiteMmapRo       4004 bytes ( 0.0 MB)  1001
Tensor   3 MobilenetV1/Logits/Conv2d_1c_1x1/weights_quant/FakeQuantWithMinMaxVars kTfLiteUInt8   kTfLiteMmapRo    1025024 bytes ( 1.0 MB)  1001 1 1 1024
Tensor   4 MobilenetV1/MobilenetV1/Conv2d_0/Conv2D_Fold_bias kTfLiteInt32   kTfLiteMmapRo        128 bytes ( 0.0 MB)  32
Tensor   5 MobilenetV1/MobilenetV1/Conv2d_0/Relu6 kTfLiteUInt8  kTfLiteArenaRw     401408 bytes ( 0.4 MB)  1 112 112 32
 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值