The process of running inference with TFLite
const char* filename = argv[1];
// Load model
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(filename);
TFLITE_MINIMAL_CHECK(model != nullptr);
// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model.get(), resolver);
std::unique_ptr<Interpreter> interpreter;
builder(&interpreter); // invokes TfLiteStatus InterpreterBuilder::operator()(std::unique_ptr<Interpreter>* interpreter);
TFLITE_MINIMAL_CHECK(interpreter != nullptr);
// Allocate tensor buffers.
TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
printf("=== Pre-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(interpreter.get());
// Fill input buffers
// TODO(user): Insert code to fill input tensors
// Run inference
TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
printf("\n\n=== Post-invoke Interpreter State ===\n");
tflite::PrintInterpreterState(interpreter.get());
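The TODO(user) line above marks where application code copies data into the input tensors. A minimal sketch of filling and reading the buffers around Invoke(), assuming (hypothetically) a model with a single float32 input and a single float32 output of 4 elements each:
// Hypothetical fill/read; tensor count, type and size depend on the model.
float* input = interpreter->typed_input_tensor<float>(0);
for (int i = 0; i < 4; ++i) input[i] = 1.0f;              // fill input buffer
TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);  // run inference
float* output = interpreter->typed_output_tensor<float>(0);
for (int i = 0; i < 4; ++i) printf("%f\n", output[i]);     // read results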
Classes involved in building the interpreter: FlatBufferModel, InterpreterBuilder, Interpreter
InterpreterBuilder::operator()
The code that creates the Interpreter lives in model.cc, in the overloaded InterpreterBuilder::operator():
TfLiteStatus InterpreterBuilder::operator()( std::unique_ptr<Interpreter>* interpreter, int num_threads) {
if (model_->version() != TFLITE_SCHEMA_VERSION) {
  // schema-version mismatch: report the error and bail out (elided)
}
if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
  // op-registration lookup failed: report the error and bail out (elided)
}
auto* subgraphs = model_->subgraphs();
auto* buffers = model_->buffers();
if (subgraphs->size() != 1) {
}
const tflite::SubGraph* subgraph = (*subgraphs)[0];
auto operators = subgraph->operators();
auto tensors = subgraph->tensors();
if (!operators || !tensors || !buffers) {
}
interpreter->reset(new Interpreter(error_reporter_));
if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
  return cleanup_and_error();
}
// Set num threads
(**interpreter).SetNumThreads(num_threads);
// Parse inputs/outputs
(**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
(**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
// Finally setup nodes and tensors
if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
std::vector<int> variables;
for (int i = 0; i < (*interpreter)->tensors_size(); ++i) {
auto* tensor = (*interpreter)->tensor(i);
if (tensor->is_variable) {
variables.push_back(i);
}
}
(**interpreter).SetVariables(std::move(variables));
return kTfLiteOk;
}
1. model_: obtaining the ::tflite::Model from the FlatBufferModel
const ::tflite::Model* model_;
FlatBufferModel::FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter)
: error_reporter_(ValidateErrorReporter(error_reporter)) {
allocation_ = allocation;
if (!allocation_->valid() || !CheckModelIdentifier()) return;
model_ = ::tflite::GetModel(allocation_->base());
}
uint32_t version() const {
return GetField<uint32_t>(VT_VERSION, 0);
}
2. BuildLocalIndexToRegistrationMapping()
// For each entry in model_->operator_codes() from the model file, look up the OpResolver to obtain the corresponding TfLiteRegistration.
TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
TfLiteStatus status = kTfLiteOk;
auto opcodes = model_->operator_codes();
for (const OperatorCode* opcode : *opcodes) {  // range-based iteration here, not array-style Get(i)
const TfLiteRegistration* registration = nullptr;
status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
&registration);
flatbuffer_op_index_to_registration_.push_back(registration);
}
return status;
}
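For example, if model_->operator_codes() holds [CONV_2D, RELU], then flatbuffer_op_index_to_registration_ ends up as [FindOp(CONV_2D, version), FindOp(RELU, version)], and each Operator in the graph later refers back into this table through its opcode_index() (see ParseNodes in section 5).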
From each OperatorCode stored in the flatbuffer model, opcode->builtin_code() is read and passed to the generic interface GetRegistrationFromOpCode. That call depends on op_resolver_, which is whatever resolver the user supplies, e.g. a BuiltinOpResolver.
The inheritance hierarchy of BuiltinOpResolver is as follows:
2.1 class OpResolver
core/api/op_resolver.h
namespace tflite {
// Abstract interface that returns TfLiteRegistrations given op codes or custom
// op names. This is the mechanism that ops being referenced in the flatbuffer
// model are mapped to executable function pointers (TfLiteRegistrations).
class OpResolver {
public:
// Finds the op registration for a builtin operator by enum code.
virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const = 0;
// Finds the op registration of a custom operator by op name.
virtual const TfLiteRegistration* FindOp(const char* op,
int version) const = 0;
virtual ~OpResolver() {}
};
// Handles the logic for converting between an OperatorCode structure extracted
// from a flatbuffer and information about a registered operator implementation.
TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
const OpResolver& op_resolver,
ErrorReporter* error_reporter,
const TfLiteRegistration** registration);
} // namespace tflite
TfLiteStatus GetRegistrationFromOpCode(
const OperatorCode* opcode, const OpResolver& op_resolver,
ErrorReporter* error_reporter, const TfLiteRegistration** registration) {
TfLiteStatus status = kTfLiteOk;
*registration = nullptr;
auto builtin_code = opcode->builtin_code();
int version = opcode->version();
if (builtin_code != BuiltinOperator_CUSTOM) {
*registration = op_resolver.FindOp(builtin_code, version);
} else if (!opcode->custom_code()) {
    // custom op without a custom_code name: report error (elided)
} else {
const char* name = opcode->custom_code()->c_str();
*registration = op_resolver.FindOp(name, version);
}
return status;
}
2.2 MutableOpResolver
mutable_op_resolver.h
// An OpResolver that is mutable, also used as the op in gen_op_registration.
// A typical usage:
// MutableOpResolver resolver;
// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
// InterpreterBuilder(model, resolver)(&interpreter);
class MutableOpResolver : public OpResolver {
public:
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const override;
const TfLiteRegistration* FindOp(const char* op, int version) const override;
void AddBuiltin(tflite::BuiltinOperator op,
const TfLiteRegistration* registration, int min_version = 1,
int max_version = 1);
void AddCustom(const char* name, const TfLiteRegistration* registration,
int min_version = 1, int max_version = 1);
void AddAll(const MutableOpResolver& other);
private:
typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
typedef std::pair<std::string, int> CustomOperatorKey;
std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
builtins_;
std::unordered_map<CustomOperatorKey, TfLiteRegistration,
op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
custom_ops_;
};
2.3 kernel/register
kernel/register.h
namespace tflite {
namespace ops {
namespace builtin {
class BuiltinOpResolver : public MutableOpResolver {
public:
BuiltinOpResolver();
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const override;
const TfLiteRegistration* FindOp(const char* op, int version) const override;
};
} // namespace builtin
} // namespace ops
} // namespace tflite
BuiltinOpResolver::BuiltinOpResolver() {  // the constructor registers all builtin (and a few custom) ops
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
// ... (many more AddBuiltin calls elided) ...
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
}
2.4 Customizing the OpResolver
For example, when the TFLite binary is too large and needs trimming, or when an extra operator must be added.
2.4.1 Modify tflite::ops::builtin::BuiltinOpResolver in place
In register.cc:
Delete the declaration TfLiteRegistration* Register_RNN();
Delete the call AddBuiltin(BuiltinOperator_RNN, Register_RNN());
In the BUILD file, delete lstm.cc so it is no longer compiled:
name = "builtin_op_kernels",
"lstm.cc",
2.4.2 Write your own resolver class, analogous to register.h and register.cc (see the sketch below)
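A minimal sketch of such a class, assuming a hypothetical MyOpResolver that registers only the ops a particular model actually needs (the op selection and header path are illustrative and vary with the TF Lite version):
#include "tensorflow/lite/mutable_op_resolver.h"  // header path differs in older releases

namespace tflite {
namespace ops {
namespace builtin {
// Kernel registration functions implemented in the builtin kernels library.
TfLiteRegistration* Register_CONV_2D();
TfLiteRegistration* Register_RELU();
}  // namespace builtin

// Hypothetical trimmed-down resolver: only the ops this model uses are registered,
// so unused kernel sources can be dropped from the BUILD target.
class MyOpResolver : public MutableOpResolver {
 public:
  MyOpResolver() {
    AddBuiltin(BuiltinOperator_CONV_2D, builtin::Register_CONV_2D());
    AddBuiltin(BuiltinOperator_RELU, builtin::Register_RELU());
  }
};

}  // namespace ops
}  // namespace tflite

// Usage: pass it to InterpreterBuilder in place of BuiltinOpResolver:
//   tflite::ops::MyOpResolver resolver;
//   tflite::InterpreterBuilder builder(*model, resolver);
//   builder(&interpreter);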
3. (**interpreter).AddTensors(tensors->Length())
This resizes std::vector<TfLiteTensor> tensors_ and zero-initializes the new entries; tensors_.data() is then synchronized into the TfLiteContext context_ so that later accesses avoid a copy. The function does not yet pull any tensor contents out of the flatbuffer.
3.1 std::vector<TfLiteTensor> tensors_
// Tensors needed by the interpreter. Use `AddTensors` to add more blank
// tensor entries. Note, `tensors_.data()` needs to be synchronized to the
// `context_` whenever this std::vector is reallocated. Currently this
// only happens in `AddTensors()`.
std::vector<TfLiteTensor> tensors_;
3.2 AddTensors declaration
// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
// The value pointed to by `first_new_tensor_index` will be set to the
// index of the first new tensor if `first_new_tensor_index` is non-null.
TfLiteStatus AddTensors(int tensors_to_add,
int* first_new_tensor_index = nullptr);
3.3 AddTensors implementation
TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
int* first_new_tensor_index) {
int base_index = tensors_.size();
if (first_new_tensor_index) *first_new_tensor_index = base_index;
tensors_.resize(tensors_.size() + tensors_to_add);
for (int i = base_index; i < tensors_.size(); i++) {
memset(&tensors_[i], 0, sizeof(tensors_[i]));
tensors_[i].buffer_handle = kTfLiteNullBufferHandle;
}
context_.tensors = tensors_.data();
context_.tensors_size = tensors_.size();
return kTfLiteOk;
}
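A small usage sketch (the counts are hypothetical): a caller that needs extra blank tensors can also recover the index of the first new one.
int first_new_tensor = -1;
// Appends 3 zero-initialized TfLiteTensor entries; first_new_tensor receives the
// previous tensors_.size(), i.e. the index of the first newly added tensor.
TFLITE_MINIMAL_CHECK(interpreter->AddTensors(3, &first_new_tensor) == kTfLiteOk);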
4. Interpreter SetInputs/SetOutputs
(**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
(**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
4.1 FlatBufferIntArrayToVector
// flat_array points into the flatbuffer array; its elements are read via Get(i)
template <typename T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {  // type-conversion helper
// Initialize shape of tensors with null shape. Empty vectors are converted
// to nullptr for models that are constructed via flatbuffers::Pack.
if (flat_array == nullptr) {
return {};
}
std::vector<int> ret(flat_array->Length());
for (int i = 0; i < flat_array->Length(); i++) {
ret[i] = flat_array->Get(i);
}
return ret;
}
4.2 std::move(inputs)
// Array of indices representing the tensors that are inputs to the
// interpreter. [i.e. the vector holds the indices of the input tensors]
std::vector<int> inputs_;
TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
TF_LITE_ENSURE_OK(&context_,
CheckTensorIndices("inputs", inputs.data(), inputs.size()));
inputs_ = std::move(inputs);
return kTfLiteOk;
}
4.2.1 CheckTensorIndices
// Each tensor index is validated against context_.tensors_size, which was set earlier by AddTensors.
TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
const int* indices, int length) {
// Making sure kOptionalTensor is not re-defined to something other than -1.
static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1");
for (int i = 0; i < length; i++) {
int index = indices[i];
// Continue if index == kOptionalTensor before additional comparisons below,
// size_t(-1) is always >= context_tensors_size.
if (index == kOptionalTensor) {
continue;
}
if (index < 0 || static_cast<size_t>(index) >= context_.tensors_size) {
ReportError(&context_, "Invalid tensor index %d in %s\n", index, label);
consistent_ = false;
return kTfLiteError;
}
}
return kTfLiteOk;
}
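For example, with context_.tensors_size == 4, the index list {0, 3, kOptionalTensor} passes (the -1 entry is skipped), while {0, 4} fails because 4 is out of range.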
5. InterpreterBuilder::ParseNodes
TfLiteStatus InterpreterBuilder::ParseNodes(
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators, Interpreter* interpreter) {
TfLiteStatus status = kTfLiteOk;
// Reduce the number of redundant allocations
interpreter->ReserveNodes(operators->Length());
for (int i = 0; i < operators->Length(); ++i) {
const auto* op = operators->Get(i);  // flatbuffer array access via Get(i) again
int index = op->opcode_index();  // index into model_->operator_codes(), not into the OpResolver
const TfLiteRegistration* registration = flatbuffer_op_index_to_registration_[index];
BuiltinOperator op_type = static_cast<BuiltinOperator>(registration->builtin_code);
if (op->custom_options()) {
interpreter->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()),
reinterpret_cast<const char*>(op->custom_options()->data()),
op->custom_options()->size(), nullptr, registration);
} else {
void* builtin_data = nullptr;
MallocDataAllocator malloc_allocator;
TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, &malloc_allocator, &builtin_data));
interpreter->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data, registration);
}
}
return status;
}
5.1 ReserveNodes
// Ensure the internal node storage memory allocates at least `count`
// spots for node. NOTE, this doesn't actually add operators. This is an
// efficiency optimization that is subject to change.
void ReserveNodes(int count);
void Interpreter::ReserveNodes(int count) {
nodes_and_registration_.reserve(count);
}
// Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
// function pointers to actual implementation.
std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
nodes_and_registration_;
5.2 Obtaining the handler (TfLiteRegistration) for a node
flatbuffer_op_index_to_registration_[index] yields the TfLiteRegistration:
std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
This relies on the mapping built earlier by BuildLocalIndexToRegistrationMapping().
5.3 Obtaining the handler's parameters
Two parts: allocating memory to hold the builtin parameters, and extracting the builtin parameters from the flatbuffer.
5.3.1 Allocating the memory
Ultimately this uses malloc.
flatbuffer_conversions.h
// Interface class for builtin data allocations.
class BuiltinDataAllocator {
public:
virtual void* Allocate(size_t size) = 0;
virtual void Deallocate(void* data) = 0;
// Allocate a structure, but make sure it is a POD structure that doesn't
// require constructors to run. The reason we do this, is that Interpreter's C
// extension part will take ownership so destructors will not be run during
// deallocation.
template <typename T>
T* AllocatePOD() {
static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
return static_cast<T*>(this->Allocate(sizeof(T)));
}
virtual ~BuiltinDataAllocator() {}
};
model.cc
// Used to determine how the op data parsing function creates its working space.
class MallocDataAllocator : public BuiltinDataAllocator {
public:
void* Allocate(size_t size) override { return malloc(size); }
void Deallocate(void* data) override { free(data); }
};
5.3.2 ParseOpData: obtaining builtin_data
// Parse the appropriate data out of the op.
//
// This handles builtin data explicitly as there are flatbuffer schemas.
// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
// need to be released by calling `free`.
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
*builtin_data = nullptr;
switch (op_type) {
case BuiltinOperator_CONV_2D: {
TfLiteConvParams* params = allocator->AllocatePOD<TfLiteConvParams>();
if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
params->stride_height = conv_params->stride_h();
params->activation =
parse_activation(conv_params->fused_activation_function());
params->dilation_width_factor = conv_params->dilation_w_factor();
params->dilation_height_factor = conv_params->dilation_h_factor();
}
*builtin_data = reinterpret_cast<void*>(params);
break;
}
    // ... remaining builtin-operator cases elided ...
  }
  return kTfLiteOk;
}
5.4 AddNodeWithParameters
The core work is on nodes_and_registration_, followed by execution_plan_.push_back(new_node_index).
TfLiteStatus Interpreter::AddNodeWithParameters(
const std::vector<int>& inputs, const std::vector<int>& outputs,
const char* init_data, size_t init_data_size, void* builtin_data,
const TfLiteRegistration* registration, int* node_index) {
// declared earlier in the full function: an RAII wrapper that owns builtin_data
// and frees it unless released below
std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data, free);
int new_node_index = nodes_and_registration_.size();  // note: size() is the number of entries already in use, i.e. the new node's index
if (node_index) *node_index = new_node_index;
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
// take a reference to the freshly added (node, registration) pair
auto& node_and_reg = nodes_and_registration_.back();
// `first` is the TfLiteNode of the pair; free any arrays it may already hold
TfLiteNode& node = node_and_reg.first;
if (node.inputs) TfLiteIntArrayFree(node.inputs);
if (node.outputs) TfLiteIntArrayFree(node.outputs);
if (node.temporaries) TfLiteIntArrayFree(node.temporaries);
// NOTE, here we are not using move semantics yet, since our internal
// representation isn't std::vector, but in the future we would like to avoid
// copies, so we want the interface to take r-value references now.
node.inputs = ConvertVectorToTfLiteIntArray(inputs);
node.outputs = ConvertVectorToTfLiteIntArray(outputs);
node.temporaries = TfLiteIntArrayCreate(0);
// call OpInit with init_data if present, otherwise with builtin_data
if (init_data) {
node.user_data = OpInit(*registration, init_data, init_data_size);
} else {
node.user_data = OpInit(*registration,
reinterpret_cast<const char*>(builtin_data_deleter.get()), 0);
}
// transfer ownership of builtin_data to the node
node.builtin_data = builtin_data_deleter.release();
// TODO(ycling): Filling `custom_initial_data` and `custom_initial_data_size`
// properly for nodes generated by ReplaceSubgraphsWithDelegateKernels.
//custom data
if (registration->builtin_code == BuiltinOperator_CUSTOM) {
// When it's a CUSTOM op, the `custom_options` field in the Flatbuffer
// `Operator` table is passed in.
node.custom_initial_data = init_data;
node.custom_initial_data_size = init_data_size;
} else {
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
}
node.delegate = nullptr;
// `second` is the TfLiteRegistration of the pair
node_and_reg.second = *registration;
execution_plan_.push_back(new_node_index);
return kTfLiteOk;
}
5.4.1 nodes_and_registration_
// Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
// function pointers to actual implementation.
std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
nodes_and_registration_;
5.4.2 op_reg:OpInit
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
size_t length) {
if (op_reg.init == nullptr) return nullptr;
return op_reg.init(&context_, buffer, length);
}
5.4.3 execution_plan_
// WARNING: This is an experimental interface that is subject to change.
// This is a list of node indices (to index into nodes_and_registration).
// This represents a valid topological sort (dependency ordered) execution
// plan. In particular, it is valid for this ordering to contain only a
// subset of the node indices.
std::vector<int> execution_plan_;
execution_plan_.push_back(new_node_index);
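Tying sections 3 through 5 together, a hedged sketch of building a one-node graph by hand through these same Interpreter APIs (an ADD op on two 2x2 float tensors; the op choice and shapes are illustrative, in the spirit of TFLite's own unit tests):
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::Interpreter interpreter;

interpreter.AddTensors(3);                  // tensors 0,1: inputs; tensor 2: output
interpreter.SetInputs({0, 1});
interpreter.SetOutputs({2});

TfLiteQuantizationParams quant{};           // no quantization
for (int t = 0; t < 3; ++t) {
  interpreter.SetTensorParametersReadWrite(t, kTfLiteFloat32, "t", {2, 2}, quant);
}

// builtin_data must come from malloc: the Interpreter takes ownership and frees it.
auto* params = reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
params->activation = kTfLiteActNone;

int node_index = -1;
interpreter.AddNodeWithParameters(
    {0, 1}, {2}, /*init_data=*/nullptr, /*init_data_size=*/0, params,
    resolver.FindOp(tflite::BuiltinOperator_ADD, /*version=*/1), &node_index);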
6. ParseTensors(buffers, tensors, interpreter->get())
Each index selects context_.tensors[tensor_index], which is then filled in, including its allocation type (read-only vs. read-write).
TfLiteStatus InterpreterBuilder::ParseTensors(
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
Interpreter* interpreter) {
TfLiteStatus status = kTfLiteOk;  // declared in the full source; set on the error paths below
for (int i = 0; i < tensors->Length(); ++i) {
const auto* tensor = tensors->Get(i);  // flatbuffer array access via Get(i) again
std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());
TfLiteQuantizationParams quantization;
quantization.scale = 0;
quantization.zero_point = 0;
auto* q_params = tensor->quantization();
if (q_params) { /* read scale and zero_point from q_params (elided) */ }
TfLiteType type;
if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
kTfLiteOk) {
}
size_t buffer_size = 0;
const char* buffer_ptr;
// get_readonly_data is a helper defined earlier in the full function; it points
// buffer_ptr at the tensor's constant data inside the flatbuffer, if any
TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
bool is_variable = tensor->is_variable();
if (buffer_ptr) {
if (is_variable) {
status = kTfLiteError;
}
if (interpreter->SetTensorParametersReadOnly(
i, type, get_name(tensor), dims, quantization, buffer_ptr,
buffer_size, allocation_) != kTfLiteOk) {
status = kTfLiteError;
}
} else {
if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor),
dims, quantization,
is_variable) != kTfLiteOk) {
status = kTfLiteError;
}
}
}
return status;
}
6.1 ConvertTensorType(tensor->type(), &type, error_reporter_)
The tensor type enums differ between the flatbuffer model and the TFLite runtime.
Tensor types in the flatbuffer schema:
enum TensorType {
TensorType_FLOAT32 = 0,
TensorType_FLOAT16 = 1,
TensorType_INT32 = 2,
TensorType_UINT8 = 3,
TensorType_INT64 = 4,
TensorType_STRING = 5,
TensorType_BOOL = 6,
TensorType_INT16 = 7,
TensorType_COMPLEX64 = 8,
TensorType_MIN = TensorType_FLOAT32,
TensorType_MAX = TensorType_COMPLEX64
};
Types in the TFLite runtime:
typedef enum {
kTfLiteNoType = 0,
kTfLiteFloat32 = 1,
kTfLiteInt32 = 2,
kTfLiteUInt8 = 3,
kTfLiteInt64 = 4,
kTfLiteString = 5,
kTfLiteBool = 6,
kTfLiteInt16 = 7,
kTfLiteComplex64 = 8,
} TfLiteType;
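The conversion is a one-to-one switch; an abbreviated sketch of what ConvertTensorType does (only two cases shown, the rest map analogously):
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
                               ErrorReporter* error_reporter) {
  switch (tensor_type) {
    case TensorType_FLOAT32:
      *type = kTfLiteFloat32;
      break;
    case TensorType_UINT8:
      *type = kTfLiteUInt8;
      break;
    // ... remaining cases ...
    default:
      // unsupported type: reported through error_reporter
      return kTfLiteError;
  }
  return kTfLiteOk;
}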
6.2 SetTensorParametersReadOnly
// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
inline TfLiteStatus SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantizationParams quantization,
const char* buffer, size_t bytes,
const Allocation* allocation = nullptr) {
return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
dims.data(), quantization, buffer, bytes,
allocation);
}
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
size_t bytes, const Allocation* allocation) {
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
// For most tensors we know exactly how much memory is necessary so we can
// ensure the buffer is large enough. However, we need to skip string tensors
// because their sizes change with the contents of the individual strings.
if (type != kTfLiteString) {
size_t required_bytes;
TF_LITE_ENSURE_OK(&context_,
BytesRequired(type, dims, rank, &required_bytes));
TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
}
TfLiteTensor& tensor = context_.tensors[tensor_index];
if (type == tensor.type &&
EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) {
// Fast path which does not invalidate the invokable property.
TfLiteTensorDataFree(&tensor);
tensor.data.raw = const_cast<char*>(buffer);
if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims);
tensor.params = quantization;
tensor.allocation_type = kTfLiteMmapRo;
tensor.allocation = allocation;  // open question: how is tensor.allocation used? For allocating memory?
} else {
state_ = kStateUninvokable;
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, false, &tensor);
}
return kTfLiteOk;
}
6.2.1 BytesRequired
TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
size_t dims_size, size_t* bytes) {
size_t count = 1;
for (int k = 0; k < dims_size; k++) count *= dims[k];
switch (type) {
case kTfLiteFloat32:
*bytes = sizeof(float) * count;
break;
    // ... other types computed similarly ...
  }
  return kTfLiteOk;
}
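For example, a kTfLiteFloat32 tensor with dims {1, 224, 224, 3} requires 1 * 224 * 224 * 3 * sizeof(float) = 602112 bytes.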
6.2.2 EqualArrayAndTfLiteIntArray
// Checks whether a `TfLiteIntArray` and an int array have matching elements.
// The caller must guarantee that 'b' has at least 'b_size' elements.
bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
const int* b);
bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
const int* b) {
if (!a) return false;
if (a->size != b_size) return false;
for (int i = 0; i < a->size; ++i) {
if (a->data[i] != b[i]) return false;
}
return true;
}
6.2.3 TfLiteTensorDataFree
// Free data memory of tensor `t`;
void TfLiteTensorDataFree(TfLiteTensor* t);
void TfLiteTensorDataFree(TfLiteTensor* t) {
if (t->allocation_type == kTfLiteDynamic && t->data.raw) {  // only dynamically allocated buffers are freed
free(t->data.raw);
}
t->data.raw = NULL;
}
6.2.4 TfLiteTensorReset
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
tensor->dims = dims;
tensor->params = quantization;
tensor->data.raw = buffer;
tensor->bytes = size;
tensor->allocation_type = allocation_type;
tensor->allocation = allocation;
tensor->is_variable = is_variable;
}
6.3 SetTensorParametersReadWrite
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
size_t required_bytes = 0;
if (type != kTfLiteString) {
TF_LITE_ENSURE_OK(&context_,
BytesRequired(type, dims, rank, &required_bytes));
}
TfLiteAllocationType allocation_type = kTfLiteArenaRw;
if (type == kTfLiteString) {
if (is_variable) {
  // string variable tensors are not supported: report error (elided)
}
allocation_type = kTfLiteDynamic;
} else if (is_variable) {
allocation_type = kTfLiteArenaRwPersistent;
}
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization,
/*buffer=*/nullptr, required_bytes, allocation_type,
nullptr, is_variable, &context_.tensors[tensor_index]);
return kTfLiteOk;
}
Memory allocation for Rw/RwPersistent tensors: the memory that backs the tensor data is resolved here (by the ArenaPlanner), whereas Ro tensors get no allocation because they point straight into the flatbuffer.
TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) {
TfLiteTensor& tensor = *graph_info_->tensor(tensor_index);
if (tensor.allocation_type == kTfLiteArenaRw) {
if (allocs_[tensor_index].size != 0) {
TF_LITE_ENSURE_STATUS(arena_.ResolveAlloc(context_, allocs_[tensor_index],
&tensor.data.raw));
}
}
if (tensor.allocation_type == kTfLiteArenaRwPersistent) {
TF_LITE_ENSURE_STATUS(persistent_arena_.ResolveAlloc(
context_, allocs_[tensor_index], &tensor.data.raw));
}
return kTfLiteOk;
}