1. Reading the Settings values in the main function
The tflite::label_image::Main function:
All of the input arguments are stored in a Settings struct:
`./label_image` options:

| Option | Value | Description |
|---|---|---|
| -i | ./grace_hopper.bmp | input image |
| -l | ./labels.txt | labels used for the output, e.g. background, tench, goldfish, great white shark, tiger shark, hammerhead, ... The file contains many entries and can of course be replaced with Chinese ones. The output also prints the confidence for each label |
| -m | ./mobilenet_quant_v1_224.tflite | model file |
| -a | 0 | whether to use the Android NNAPI for acceleration [interpreter->UseNNAPI(s->accel);] |
| -c | 1 | loop count, loop_count [for (int i = 0; i < s->loop_count; i++) { if (interpreter->Invoke() != kTfLiteOk) { LOG(FATAL) << "Failed ...] |
| -b | 128 | input mean, defaults to 127.5 in the code [used together with input std to normalize float inputs; see the sketch after this table] |
| -p | 0 | whether to enable profiling [used to measure where the runtime is spent] |
| -t | 1 | number of threads [if (s->number_of_threads != -1) { interpreter->SetNumThreads(s->number_of_threads); }] |
| -v | 1 | whether to print more runtime information |
| -s | | input std, defaults to 127.5 in the code |
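As a rough illustration of what -b (input mean) and -s (input std) are for: when the input tensor is kTfLiteFloat32, the example normalizes each pixel with these two values while filling the tensor; quantized uint8 models take the raw bytes unchanged. The helper below is a minimal sketch of that idea, not the actual resize<> template from bitmap_helpers, and the function name is made up.

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative only: how input_mean / input_std are typically applied when
// filling a float input tensor (quantized uint8 models skip this step).
void FillFloatInput(float* dst, const uint8_t* src, size_t count,
                    float input_mean, float input_std) {
  for (size_t i = 0; i < count; ++i) {
    dst[i] = (static_cast<float>(src[i]) - input_mean) / input_std;
  }
}
```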
The defaults in the header file:
external/tensorflow$ vi tensorflow/contrib/lite/examples/label_image/label_image.h +24
-
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H

#include "tensorflow/contrib/lite/string.h"

namespace tflite {
namespace label_image {

struct Settings {
  bool verbose = false;
  bool accel = false;
  bool input_floating = false;
  int loop_count = 1;
  float input_mean = 127.5f;
  float input_std = 127.5f;
  string model_name = "./mobilenet_quant_v1_224.tflite";
  string input_bmp_name = "./grace_hopper.bmp";
  string labels_file_name = "./labels.txt";
  string input_layer_type = "uint8_t";
  int number_of_threads = 4;
};

}  // namespace label_image
}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
The corresponding long options in the main function:
-
static struct option long_options[] = {
    {"accelerated", required_argument, 0, 'a'},
    {"count", required_argument, 0, 'c'},
    {"verbose", required_argument, 0, 'v'},
    {"image", required_argument, 0, 'i'},
    {"labels", required_argument, 0, 'l'},
    {"tflite_model", required_argument, 0, 'm'},
    {"threads", required_argument, 0, 't'},
    {"input_mean", required_argument, 0, 'b'},
    {"input_std", required_argument, 0, 's'},
    {0, 0, 0, 0}};
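For completeness, this is roughly how main() turns those long options into the Settings struct with getopt_long. ParseFlags is a made-up helper name (the real main() does this parsing inline), only some of the flags are shown, and error handling around strtol/strtod is omitted.

```cpp
#include <getopt.h>
#include <cstdlib>
// Assumes the Settings struct and long_options array shown above.

void ParseFlags(int argc, char** argv, tflite::label_image::Settings* s) {
  int c;
  while ((c = getopt_long(argc, argv, "a:b:c:i:l:m:p:s:t:v:",
                          long_options, nullptr)) != -1) {
    switch (c) {
      case 'a': s->accel = strtol(optarg, nullptr, 10); break;       // NNAPI on/off
      case 'b': s->input_mean = strtod(optarg, nullptr); break;      // input mean
      case 'c': s->loop_count = strtol(optarg, nullptr, 10); break;  // loop count
      case 'i': s->input_bmp_name = optarg; break;                   // image
      case 'l': s->labels_file_name = optarg; break;                 // labels
      case 'm': s->model_name = optarg; break;                       // model
      case 's': s->input_std = strtod(optarg, nullptr); break;       // input std
      case 't': s->number_of_threads = strtol(optarg, nullptr, 10); break;
      case 'v': s->verbose = strtol(optarg, nullptr, 10); break;
      default: break;
    }
  }
}
```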
2. RunInference(&s)
First, the builders: model.h contains two classes, FlatBufferModel and InterpreterBuilder.
I used to think FlatBufferModel was for building a read-only model and InterpreterBuilder for building a modifiable one, but judging from the comments that understanding is wrong: FlatBufferModel wraps the tflite model itself, while InterpreterBuilder builds the interpreter.
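Put together, the two classes are used in the same pattern that RunInference follows later in this post. This is a minimal sketch (the function name is made up); note the comment in model.h that the model must outlive the interpreter, so both are kept alive here.

```cpp
#include <memory>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

// Minimal sketch: FlatBufferModel loads the .tflite file, InterpreterBuilder
// wires (model, op resolver) into an Interpreter.
void BuildSketch() {
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile("./mobilenet_quant_v1_224.tflite");
  if (!model) return;  // BuildFromFile returns nullptr on failure

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter) return;
  // ... AllocateTensors(), fill inputs, Invoke(), read outputs ...
}
```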
-
// An RAII object that represents a read-only tflite model, copied from disk,
// or mmapped. This uses flatbuffers as the serialization format.
// (flatbuffers: a fixed-layout serialization format, if I remember correctly)
class FlatBufferModel {
 public:
  // Builds a model based on a file. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromFile(
      const char* filename,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Builds a model based on a pre-loaded flatbuffer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
      const char* buffer, size_t buffer_size,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Builds a model directly from a flatbuffer pointer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromModel(
      const tflite::Model* model_spec,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Releases memory or unmaps mmapped memory.
  ~FlatBufferModel();

  // Copying or assignment is disallowed to simplify ownership semantics.
  FlatBufferModel(const FlatBufferModel&) = delete;
  FlatBufferModel& operator=(const FlatBufferModel&) = delete;

  bool initialized() const { return model_ != nullptr; }
  const tflite::Model* operator->() const { return model_; }
  const tflite::Model* GetModel() const { return model_; }
  ErrorReporter* error_reporter() const { return error_reporter_; }
  const Allocation* allocation() const { return allocation_; }

  // Returns true if the model identifier is correct (otherwise false and
  // reports an error).
  bool CheckModelIdentifier() const;

 private:
  // Loads a model from `filename`. If `mmap_file` is true then use mmap,
  // otherwise make a copy of the model in a buffer.
  //
  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
  // used.
  explicit FlatBufferModel(
      const char* filename, bool mmap_file = true,
      ErrorReporter* error_reporter = DefaultErrorReporter(),
      bool use_nnapi = false);

  // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
  // to remain alive and unchanged until the end of this flatbuffermodel's
  // lifetime.
  //
  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
  // used.
  FlatBufferModel(const char* ptr, size_t num_bytes,
                  ErrorReporter* error_reporter = DefaultErrorReporter());

  // Loads a model from Model flatbuffer. The `model` has to remain alive and
  // unchanged until the end of this flatbuffermodel's lifetime.
  FlatBufferModel(const Model* model, ErrorReporter* error_reporter);

  // Flatbuffer traverser pointer. (Model* is a pointer that is within the
  // allocated memory of the data allocated by allocation's internals.
  const tflite::Model* model_ = nullptr;
  ErrorReporter* error_reporter_;
  Allocation* allocation_ = nullptr;
};
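The comment on BuildFromBuffer is worth noting: unlike BuildFromFile, the caller owns the bytes. A small sketch of that variant (the helper name and reading code are assumptions for illustration):

```cpp
#include <fstream>
#include <iterator>
#include <memory>
#include <string>
#include "tensorflow/contrib/lite/model.h"

void BufferSketch(const char* path) {
  // Read the whole .tflite file into a caller-owned buffer.
  std::ifstream f(path, std::ios::binary);
  std::string bytes((std::istreambuf_iterator<char>(f)),
                    std::istreambuf_iterator<char>());

  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromBuffer(bytes.data(), bytes.size());
  // `bytes` must stay alive for as long as `model` (and any interpreter built
  // from it) is used; it may only be released after they are destroyed.
}
```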
InterpreterBuilder
-
// Build an interpreter capable of interpreting `model`.
//
// model: a scoped model whose lifetime must be at least as long as
//   the interpreter. In principle multiple interpreters can be made from
//   a single model.
// op_resolver: An instance that implements the Resolver interface which maps
//   custom op names and builtin op codes to op registrations.
// reportError: a functor that is called to report errors that handles
//   printf var arg semantics. The lifetime of the reportError object must
//   be greater than or equal to the Interpreter created by operator().
//
// Returns a kTfLiteOk when successful and sets interpreter to a valid
// Interpreter. Note: the user must ensure the model lifetime is at least as
// long as interpreter's lifetime.
class InterpreterBuilder {
 public:
  InterpreterBuilder(const FlatBufferModel& model,
                     const OpResolver& op_resolver);
  // Builds an interpreter given only the raw flatbuffer Model object (instead
  // of a FlatBufferModel). Mostly used for testing.
  // If `error_reporter` is null, then DefaultErrorReporter() is used.
  InterpreterBuilder(const ::tflite::Model* model,
                     const OpResolver& op_resolver,
                     ErrorReporter* error_reporter = DefaultErrorReporter());
  InterpreterBuilder(const InterpreterBuilder&) = delete;
  InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
                          int num_threads);

 private:
  TfLiteStatus BuildLocalIndexToRegistrationMapping();
  TfLiteStatus ParseNodes(
      const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
      Interpreter* interpreter);
  TfLiteStatus ParseTensors(
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
      const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
      Interpreter* interpreter);

  const ::tflite::Model* model_;
  const OpResolver& op_resolver_;
  ErrorReporter* error_reporter_;

  std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_;
  std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
  const Allocation* allocation_ = nullptr;
};
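The builder itself is a functor: calling operator() fills the unique_ptr, and the two-argument overload also sets the thread count in the same call. A small sketch (BuildWithThreads is a made-up helper name):

```cpp
#include <memory>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

std::unique_ptr<tflite::Interpreter> BuildWithThreads(
    const tflite::FlatBufferModel& model,
    const tflite::ops::builtin::BuiltinOpResolver& resolver, int num_threads) {
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder builder(model, resolver);
  if (builder(&interpreter, num_threads) != kTfLiteOk) {
    interpreter.reset();  // on failure the interpreter is left null
  }
  return interpreter;
}
```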
The surrounding structure of model.h looks like this:
-
#ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_
#define TENSORFLOW_CONTRIB_LITE_MODEL_H_

#include <memory>
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"

namespace tflite {

class FlatBufferModel {
  *************
};

class InterpreterBuilder {
  ****************
};

}  // namespace tflite
#endif  // TENSORFLOW_CONTRIB_LITE_MODEL_H_
Next comes Interpreter, our actual interpreter (the "translator"), declared in the interpreter.h header; the class has a great many members:
-
// The Interpreter is effectively the translator.
class Interpreter {
 public:
  // Instantiate an interpreter. All errors associated with reading and
  // processing this model will be forwarded to the error_reporter object.
  //
  // Note, if error_reporter is nullptr, then a default StderrReporter is
  // used.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;

  // Functions to build interpreter

  // Provide a list of tensor indexes that are inputs to the model.
  // Each index is bound check and this modifies the consistent_ flag of the
  // interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  // Provide a list of tensor indexes that are outputs to the model
  // Each index is bound check and this modifies the consistent_ flag of the
  // interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  // Adds a node with the given parameters and returns the index of the new
  // node in `node_index` (optionally). Interpreter will take ownership of
  // `builtin_data` and destroy it with `free`. Ownership of 'init_data'
  // remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

  // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  // The value pointed to by `first_new_tensor_index` will be set to the
  // index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  // Set description of inputs/outputs/data/fptrs for node `node_index`.
  // This variant assumes an external buffer has been allocated of size
  // bytes. The lifetime of buffer must be ensured to be greater or equal
  // to Interpreter.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  // Set description of inputs/outputs/data/fptrs for node `node_index`.
  // This variant assumes an external buffer has been allocated of size
  // bytes. The lifetime of buffer must be ensured to be greater or equal
  // to Interpreter.
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization);

  // Functions to access tensor data

  // Read only access to list of inputs.
  const std::vector<int>& inputs() const { return inputs_; }

  // Return the name of a given input. The given index must be between 0 and
  // inputs().size().
  const char* GetInputName(int index) const {
    return context_.tensors[inputs_[index]].name;
  }

  // Read only access to list of outputs.
  const std::vector<int>& outputs() const { return outputs_; }

  // Return the name of a given output. The given index must be between 0 and
  // outputs().size().
  const char* GetOutputName(int index) const {
    return context_.tensors[outputs_[index]].name;
  }

  // Return the number of tensors in the model.
  int tensors_size() const { return context_.tensors_size; }

  // Return the number of ops in the model.
  int nodes_size() const { return nodes_and_registration_.size(); }

  // WARNING: Experimental interface, subject to change
  const std::vector<int>& execution_plan() const { return execution_plan_; }

  // WARNING: Experimental interface, subject to change
  // Overrides execution plan. This bounds checks indices sent in.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Get a tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  TfLiteTensor* tensor(int tensor_index) {
    if (tensor_index >= context_.tensors_size || tensor_index < 0)
      return nullptr;
    return &context_.tensors[tensor_index];
  }

  // Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    if (tensor_index >= context_.tensors_size || tensor_index < 0)
      return nullptr;
    return &context_.tensors[tensor_index];
  }

  // Get a pointer to an operation and registration data structure if in bounds.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    if (node_index >= nodes_and_registration_.size() || node_index < 0)
      return nullptr;
    return &nodes_and_registration_[node_index];
  }

  // Perform a checked cast to the appropriate tensor type.
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  // Return a pointer into the data of a given input tensor. The given index
  // must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs_[index]);
  }

  // Return a pointer into the data of a given output tensor. The given index
  // must be between 0 and outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs_[index]);
  }

  // Change the dimensionality of a given tensor. Note, this is only acceptable
  // for tensor indices that are inputs.
  // Returns status of failure or success.
  // TODO(aselle): Consider implementing ArraySlice equivalent to make this
  // more adept at accepting data without an extra copy. Use absl::ArraySlice
  // if our partners determine that dependency is acceptable.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  // Update allocations for all tensors. This will redim dependent tensors using
  // the input tensor dimensionality as given. This is relatively expensive.
  // If you know that your sizes are not changing, you need not call this.
  // Returns status of success or failure.
  TfLiteStatus AllocateTensors();

  // Invoke the interpreter (run the whole graph in dependency order).
  //
  // NOTE: It is possible that the interpreter is not in a ready state
  // to evaluate (i.e. if a ResizeTensor() has been performed without an
  // AllocateTensors().
  // Returns status of success or failure.
  TfLiteStatus Invoke();  // this feels like the most important function here

  // Enable or disable the NN API (true to enable)
  void UseNNAPI(bool enable);

  // Set the number of threads available to the interpreter.
  void SetNumThreads(int num_threads);

  // Allow a delegate to look at the graph and modify the graph to handle
  // parts of the graph themselves. After this is called, the graph may
  // contain new nodes that replace 1 more nodes.
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);

  // Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    // haili TODO:
    // if (op_reg.profiling_string == nullptr) return nullptr;
    // return op_reg.profiling_string(&context_, node);
    return nullptr;
  }

  void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; }

  profiling::Profiler* GetProfiler() { return profiler_; }

 private:
  // Give 'op_reg' a chance to initialize itself using the contents of
  // 'buffer'.
  void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
               size_t length) {
    if (op_reg.init == nullptr) return nullptr;
    return op_reg.init(&context_, buffer, length);
  }

  // Let 'op_reg' release any memory it might have allocated via 'OpInit'.
  void OpFree(const TfLiteRegistration& op_reg, void* buffer) {
    if (op_reg.free == nullptr) return;
    if (buffer) {
      op_reg.free(&context_, buffer);
    }
  }

  // Prepare the given 'node' for execution.
  TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) {
    if (op_reg.prepare == nullptr) return kTfLiteOk;
    return op_reg.prepare(&context_, node);
  }

  // Invoke the operator represented by 'node'.
  TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) {
    if (op_reg.invoke == nullptr) return kTfLiteError;
    return op_reg.invoke(&context_, node);
  }

  // Call OpPrepare() for as many ops as possible, allocating memory for their
  // tensors. If an op containing dynamic tensors is found, preparation will be
  // postponed until this function is called again. This allows the interpreter
  // to wait until Invoke() to resolve the sizes of dynamic tensors.
  TfLiteStatus PrepareOpsAndTensors();

  // Call OpPrepare() for all ops starting at 'first_node'. Stop when a
  // dynamic tensors is found or all ops have been prepared. Fill
  // 'last_node_prepared' with the id of the op containing dynamic tensors, or
  // the last in the graph.
  TfLiteStatus PrepareOpsStartingAt(int first_execution_plan_index,
                                    int* last_execution_plan_index_prepared);

  // Tensors needed by the interpreter. Use `AddTensors` to add more blank
  // tensor entries. Note, `tensors_.data()` needs to be synchronized to the
  // `context_` whenever this std::vector is reallocated. Currently this
  // only happens in `AddTensors()`.
  std::vector<TfLiteTensor> tensors_;

  // Check if an array of tensor indices are valid with respect to the Tensor
  // array.
  // NOTE: this changes consistent_ to be false if indices are out of bounds.
  TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
                                  int length);

  // Compute the number of bytes required to represent a tensor with dimensions
  // specified by the array dims (of length dims_size). Returns the status code
  // and bytes.
  TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size,
                             size_t* bytes);

  // Request an tensor be resized implementation. If the given tensor is of
  // type kTfLiteDynamic it will also be allocated new memory.
  TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size);

  // Report a detailed error string (will be printed to stderr).
  // TODO(aselle): allow user of class to provide alternative destinations.
  void ReportErrorImpl(const char* format, va_list args);

  // Entry point for C node plugin API to request an tensor be resized.
  static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor,
                                   TfLiteIntArray* new_size);

  // Entry point for C node plugin API to report an error.
  static void ReportError(TfLiteContext* context, const char* format, ...);

  // Entry point for C node plugin API to add new tensors.
  static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add,
                                 int* first_new_tensor_index);

  // WARNING: This is an experimental API and subject to change.
  // Entry point for C API ReplaceSubgraphsWithDelegateKernels
  static TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
      TfLiteContext* context, TfLiteRegistration registration,
      const TfLiteIntArray* nodes_to_replace);

  // Update the execution graph to replace some of the nodes with stub
  // nodes. Specifically any node index that has `nodes[index]==1` will be
  // slated for replacement with a delegate kernel specified by registration.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
      TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace);

  // WARNING: This is an experimental interface that is subject to change.
  // Gets the internal pointer to a TensorFlow lite node by node_index.
  TfLiteStatus GetNodeAndRegistration(int node_index, TfLiteNode** node,
                                      TfLiteRegistration** registration);

  // WARNING: This is an experimental interface that is subject to change.
  // Entry point for C node plugin API to get a node by index.
  static TfLiteStatus GetNodeAndRegistration(struct TfLiteContext*,
                                             int node_index, TfLiteNode** node,
                                             TfLiteRegistration** registration);

  // WARNING: This is an experimental interface that is subject to change.
  // Gets an TfLiteIntArray* representing the execution plan. The caller owns
  // this memory and must free it with TfLiteIntArrayFree().
  TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan);

  // WARNING: This is an experimental interface that is subject to change.
  // Entry point for C node plugin API to get the execution plan
  static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
                                       TfLiteIntArray** execution_plan);

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  TfLiteContext context_;

  // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
  // function pointers to actual implementation.
  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
      nodes_and_registration_;

  // Whether the model is consistent. That is to say if the inputs and outputs
  // of every node and the global inputs and outputs are valid indexes into
  // the tensor array.
  bool consistent_ = true;

  // Whether the model is safe to invoke (if any errors occurred this
  // will be false).
  bool invokable_ = false;

  // Array of indices representing the tensors that are inputs to the
  // interpreter.
  std::vector<int> inputs_;

  // Array of indices representing the tensors that are outputs to the
  // interpreter.
  std::vector<int> outputs_;

  // The error reporter delegate that tflite will forward queries errors to.
  ErrorReporter* error_reporter_;

  // Index of the next node to prepare.
  // During Invoke(), Interpreter will allocate input tensors first, which are
  // known to be fixed size. Then it will allocate outputs from nodes as many
  // as possible. When there is a node that produces dynamic sized tensor.
  // Intepreter will stop allocating tensors, set the value of next allocate
  // node id, and execute the node to generate the output tensor before continue
  // to allocate successors. This process repeats until all nodes are executed.
  // NOTE: this relies on the order of nodes that is in topological order.
  int next_execution_plan_index_to_prepare_;

  // WARNING: This is an experimental interface that is subject to change.
  // This is a list of node indices (to index into nodes_and_registration).
  // This represents a valid topological sort (dependency ordered) execution
  // plan. In particular, it is valid for this ordering to contain only a
  // subset of the node indices.
  std::vector<int> execution_plan_;

  // In the future, we'd like a TfLiteIntArray compatible representation.
  // TODO(aselle): replace execution_plan_ with this.
  std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> plan_cache_;

  // Whether to delegate to NN API
  std::unique_ptr<NNAPIDelegate> nnapi_delegate_;

  std::unique_ptr<MemoryPlanner> memory_planner_;

  // Profiler for this interpreter instance.
  profiling::Profiler* profiler_;
};
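From the caller's point of view, only a handful of these members matter for a simple run. A minimal sketch of the call sequence for a float model with a single input and output (the function name, shapes and types are assumptions for illustration):

```cpp
#include <cstring>
#include "tensorflow/contrib/lite/interpreter.h"

TfLiteStatus RunOnce(tflite::Interpreter* interpreter,
                     const float* pixels, int pixel_count) {
  // Allocate memory for all tensors before touching their data pointers.
  if (interpreter->AllocateTensors() != kTfLiteOk) return kTfLiteError;

  // inputs()[0] is the tensor index; typed_input_tensor gives the data pointer.
  float* input = interpreter->typed_input_tensor<float>(0);
  std::memcpy(input, pixels, pixel_count * sizeof(float));

  // Run the whole graph in dependency order.
  if (interpreter->Invoke() != kTfLiteOk) return kTfLiteError;

  // Read the scores back from the first output tensor.
  float* scores = interpreter->typed_output_tensor<float>(0);
  (void)scores;  // e.g. pick the arg-max / top-N here
  return kTfLiteOk;
}
```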
Building the OpResolver
-
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_

#include <unordered_map>
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/model.h"

namespace tflite {
namespace ops {
namespace builtin {

// OpResolver is the base class; it maintains the mapping from op names /
// builtin op codes to their kernel registrations (function pointers).
class BuiltinOpResolver : public OpResolver {
 public:
  BuiltinOpResolver();
  TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
  TfLiteRegistration* FindOp(const char* op) const override;
  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
  void AddCustom(const char* name, TfLiteRegistration* registration);

 private:
  struct BuiltinOperatorHasher {
    size_t operator()(const tflite::BuiltinOperator& x) const {
      return std::hash<size_t>()(static_cast<size_t>(x));
    }
  };
  std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*,
                     BuiltinOperatorHasher> builtins_;
  std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
};

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
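AddCustom is how an application could hang its own kernel onto the resolver. A hedged sketch below; "MY_CUSTOM_OP", MyCustomInvoke and RegisterCustomOp are made-up names, and a real kernel would also provide prepare/init/free callbacks.

```cpp
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/register.h"

// Kernel body for the hypothetical custom op.
TfLiteStatus MyCustomInvoke(TfLiteContext* context, TfLiteNode* node) {
  // ... compute the node's outputs from node->inputs ...
  return kTfLiteOk;
}

void RegisterCustomOp(tflite::ops::builtin::BuiltinOpResolver* resolver) {
  // Remaining TfLiteRegistration members are zero-initialized.
  static TfLiteRegistration reg = {/*init=*/nullptr, /*free=*/nullptr,
                                   /*prepare=*/nullptr,
                                   /*invoke=*/MyCustomInvoke};
  resolver->AddCustom("MY_CUSTOM_OP", &reg);
}
```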
The complete RunInference function:
-
double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }

void RunInference(Settings* s) {
  if (!s->model_name.c_str()) {
    LOG(ERROR) << "no model file name\n";
    exit(-1);
  }

  std::unique_ptr<tflite::FlatBufferModel> model;
  std::unique_ptr<tflite::Interpreter> interpreter;

  // 1) build the model
  /*
   public:
    // Builds a model based on a file. Returns a nullptr in case of failure.
    static std::unique_ptr<FlatBufferModel> BuildFromFile(
        const char* filename,
        ErrorReporter* error_reporter = DefaultErrorReporter());
  */
  model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
  if (!model) {
    LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
    exit(-1);
  }
  LOG(INFO) << "Loaded model " << s->model_name << "\n";
  /* ErrorReporter* error_reporter() const { return error_reporter_; } */
  model->error_reporter();
  LOG(INFO) << "resolved reporter\n";

  // 2) build the OpResolver, which maps each node to its op implementation
  tflite::ops::builtin::BuiltinOpResolver resolver;

  // 3) build the interpreter
  /*
    // Builds an interpreter given only the raw flatbuffer Model object (instead
    // of a FlatBufferModel). Mostly used for testing.
    // If `error_reporter` is null, then DefaultErrorReporter() is used.
    InterpreterBuilder(const ::tflite::Model* model,
                       const OpResolver& op_resolver,
                       ErrorReporter* error_reporter = DefaultErrorReporter());
    The second argument is passed by reference; there are actually several
    constructors, maybe this is true or not.
  */
  // What the build produces is a class Interpreter; the call below fills the
  // unique_ptr with it.
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter) {
    LOG(FATAL) << "Failed to construct interpreter\n";
    exit(-1);
  }

  // 4) configure the interpreter, including:
  interpreter->UseNNAPI(s->accel);  // see the remaining functions in class Interpreter

  if (s->verbose) {
    LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
    LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
    LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
    LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";

    int t_size = interpreter->tensors_size();
    for (int i = 0; i < t_size; i++) {
      // tensor(i) returns a TfLiteTensor; the tensors in the model are loaded
      // into this format
      if (interpreter->tensor(i)->name)
        LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
                  << interpreter->tensor(i)->bytes << ", "
                  << interpreter->tensor(i)->type << ", "
                  << interpreter->tensor(i)->params.scale << ", "
                  << interpreter->tensor(i)->params.zero_point << "\n";
    }
  }

  if (s->number_of_threads != -1) {
    interpreter->SetNumThreads(s->number_of_threads);
  }

  // 5) read the bmp file and resize it as needed
  int image_width = 224;
  int image_height = 224;
  int image_channels = 3;
  // examples/label_image/bitmap_helpers.cc is worth a look
  uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height,
                         &image_channels, s);

  // why is only the first entry taken?
  int input = interpreter->inputs()[0];
  if (s->verbose) LOG(INFO) << "input: " << input << "\n";

  /*
    // Array of indices representing the tensors that are inputs to the
    // interpreter.
    std::vector<int> inputs_;
    // Array of indices representing the tensors that are outputs to the
    // interpreter.
    std::vector<int> outputs_;
  */
  const std::vector<int> inputs = interpreter->inputs();
  const std::vector<int> outputs = interpreter->outputs();

  if (s->verbose) {
    LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
    LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
  }

  /* // Returns status of success or failure.
     TfLiteStatus AllocateTensors(); */
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    LOG(FATAL) << "Failed to allocate tensors!";
  }

  // print interpreter state / runtime information
  // optional_debug_tools.cc +72
  if (s->verbose) PrintInterpreterState(interpreter.get());

  // get input dimension from the input tensor metadata
  // assuming one input only
  /*
    // Fixed size list of integers. Used for dimensions and inputs/outputs
    // tensor indices
    typedef struct {
      int size;
    // gcc 6.1+ have a bug where flexible members aren't properly handled
    // https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
    #if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
        __GNUC_MINOR__ >= 1
      int data[0];
    #else
      int data[];
    #endif
    } TfLiteIntArray;
  */
  TfLiteIntArray* dims = interpreter->tensor(input)->dims;
  int wanted_height = dims->data[1];
  int wanted_width = dims->data[2];
  int wanted_channels = dims->data[3];

  // presumably this converts the image into the type required by the input
  // tensor
  switch (interpreter->tensor(input)->type) {
    case kTfLiteFloat32:
      s->input_floating = true;
      resize<float>(interpreter->typed_tensor<float>(input), in, image_height,
                    image_width, image_channels, wanted_height, wanted_width,
                    wanted_channels, s);
      break;
    case kTfLiteUInt8:
      resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
                      image_height, image_width, image_channels, wanted_height,
                      wanted_width, wanted_channels, s);
      break;
    default:
      LOG(FATAL) << "cannot handle input type "
                 << interpreter->tensor(input)->type << " yet";
      exit(-1);
  }

  struct timeval start_time, stop_time;
  gettimeofday(&start_time, NULL);
  // run the model and measure the elapsed time
  for (int i = 0; i < s->loop_count; i++) {
    if (interpreter->Invoke() != kTfLiteOk) {
      LOG(FATAL) << "Failed to invoke tflite!\n";
    }
  }
  gettimeofday(&stop_time, NULL);
  LOG(INFO) << "invoked \n";
  LOG(INFO) << "average time: "
            << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
            << " ms \n";

  const int output_size = 1000;
  const size_t num_results = 5;
  const float threshold = 0.001f;

  std::vector<std::pair<float, int>> top_results;

  // again, why only the first output?
  int output = interpreter->outputs()[0];
  // fetch the output; as above, dispatch on the output tensor type
  switch (interpreter->tensor(output)->type) {
    case kTfLiteFloat32:
      get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
                       num_results, threshold, &top_results, true);
      break;
    case kTfLiteUInt8:
      get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
                         output_size, num_results, threshold, &top_results,
                         false);
      break;
    default:
      LOG(FATAL) << "cannot handle output type "
                 << interpreter->tensor(input)->type << " yet";
      exit(-1);
  }

  // load the labels and print the corresponding results
  std::vector<string> labels;
  size_t label_count;

  // vi examples/label_image/label_image.cc +52 -- reads the labels file
  if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
    exit(-1);

  // first is the float confidence, second is the int label index
  for (const auto& result : top_results) {
    const float confidence = result.first;
    const int index = result.second;
    LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
  }
}
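The get_top_n helper called above is not shown in this post. A hedged sketch of what it does for the float case: keep the num_results highest (score, index) pairs above the threshold, sorted by confidence. The function name and exact details here are assumptions (the real helper lives in get_top_n.h and also rescales uint8 scores).

```cpp
#include <algorithm>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

void TopNFloat(const float* scores, int count, size_t num_results,
               float threshold, std::vector<std::pair<float, int>>* out) {
  // Min-heap on the score, so the weakest of the kept candidates is on top
  // and can be evicted cheaply.
  std::priority_queue<std::pair<float, int>,
                      std::vector<std::pair<float, int>>,
                      std::greater<std::pair<float, int>>> heap;
  for (int i = 0; i < count; ++i) {
    if (scores[i] < threshold) continue;
    heap.emplace(scores[i], i);
    if (heap.size() > num_results) heap.pop();
  }
  out->clear();
  while (!heap.empty()) {
    out->push_back(heap.top());
    heap.pop();
  }
  std::reverse(out->begin(), out->end());  // highest confidence first
}
```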
ReadLabelsFile, which reads the labels file, is defined by the example itself:
-
// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
TfLiteStatus ReadLabelsFile(const string& file_name,
                            std::vector<string>* result,
                            size_t* found_label_count) {
  std::ifstream file(file_name);
  if (!file) {
    LOG(FATAL) << "Labels file " << file_name << " not found\n";
    return kTfLiteError;
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    result->push_back(line);
  }
  *found_label_count = result->size();
  const int padding = 16;
  while (result->size() % padding) {
    result->emplace_back();
  }
  return kTfLiteOk;
}
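A tiny usage sketch of the padding behaviour (PrintLabel is a made-up helper; it assumes the example's LOG macro and the declaration above): label_count holds the real number of labels, while labels.size() is rounded up to a multiple of 16.

```cpp
#include <vector>

void PrintLabel(int index) {
  std::vector<string> labels;
  size_t label_count = 0;
  if (ReadLabelsFile("./labels.txt", &labels, &label_count) != kTfLiteOk) return;
  // e.g. with 1001 labels in the file, 7 empty strings are appended and
  // labels.size() becomes 1008, so indexing with a model output index is safe.
  LOG(INFO) << labels[index] << " (" << label_count << " real labels)\n";
}
```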