C
c_api_internal.h
// This file defines a C API for implementing operations in tflite.
// These operations can be defined using C++, but the interface between
// the interpreter and the operations is C.
//
// Summary of abstractions
// TF_LITE_ENSURE - Self-sufficient error checking
// TfLiteStatus - Status reporting
// TfLiteIntArray - stores tensor shapes (dims)
// TfLiteContext - allows an op to access the tensors
// TfLiteTensor - tensor (a multidimensional array)
// TfLiteNode - a single node or operation
// TfLiteRegistration - the implementation of a conceptual operation.
builtin_op_data.h
各种operator的参数
core
定义了3个抽象接口:ErrorReporter / BuiltinDataAllocator / OpResolver
演示如何使用ErrorReporter,如何从Op中得到op参数,以及如何分配BuiltinData内存来保存得到的op参数
// Test double for ErrorReporter: instead of printing, it formats the most
// recent report into a fixed-size in-object buffer so tests can inspect what
// (if anything) was reported.
class MockErrorReporter : public ErrorReporter {
 public:
  // Starts with an empty, NUL-terminated buffer so GetBuffer() is always a
  // valid C string, even before the first Report().
  MockErrorReporter() : buffer_size_(0) { buffer_[0] = '\0'; }

  // Formats `format` with `args` into the internal buffer, truncating to
  // kBufferSize - 1 characters plus the terminating NUL.
  // Returns the vsnprintf result unchanged: the untruncated length, or a
  // negative value on an encoding error.
  int Report(const char* format, va_list args) override {
    const int written = vsnprintf(buffer_, kBufferSize, format, args);
    if (written < 0) {
      // On failure the buffer contents are unspecified; reset to empty.
      buffer_[0] = '\0';
      buffer_size_ = 0;
    } else if (written >= kBufferSize) {
      // Output was truncated: record only what actually fits in buffer_.
      buffer_size_ = kBufferSize - 1;
    } else {
      buffer_size_ = written;
    }
    return written;
  }

  // Last reported message; empty string if nothing has been reported.
  char* GetBuffer() { return buffer_; }
  // Number of characters currently stored in GetBuffer(); 0 if nothing has
  // been reported.
  int GetBufferSize() { return buffer_size_; }

 private:
  static constexpr int kBufferSize = 256;
  char buffer_[kBufferSize];
  int buffer_size_;
};
// Used to determine how the op data parsing function creates its working space.
// Single-slot allocator: it hands out one fixed in-object buffer. gtest
// expectations flag a second Allocate() before Deallocate(), and any request
// larger than the buffer.
class MockDataAllocator : public BuiltinDataAllocator {
 public:
  MockDataAllocator() : is_allocated_(false) {}

  void* Allocate(size_t size) override {
    EXPECT_FALSE(is_allocated_);
    // Copy into a local (rather than passing kBufferSize to the macro) so the
    // static member is not odr-used. Use size_t so the comparison is
    // unsigned/unsigned.
    const size_t max_size = kBufferSize;
    EXPECT_LE(size, max_size);
    is_allocated_ = true;
    return buffer_;
  }

  // Frees the single slot. The pointer itself is not inspected.
  void Deallocate(void* /*data*/) override { is_allocated_ = false; }

 private:
  static constexpr int kBufferSize = 1024;
  // A plain char array only guarantees 1-byte alignment, but callers cast the
  // returned pointer to an op params struct (e.g. TfLiteConvParams). Align it
  // for the strictest fundamental scalar types so that access is well-defined.
  alignas(long double) alignas(long long) alignas(void*) char buffer_[kBufferSize];
  bool is_allocated_;
};
上面创建了两个mock类:MockErrorReporter 和 MockDataAllocator
// Serializes a Conv2D Operator into a flatbuffer, parses its options back with
// ParseOpData, and checks every field of the resulting TfLiteConvParams.
TEST(FlatbufferConversions, TestParseOpDataConv) {
  MockErrorReporter mock_reporter;
  // Use the mock through its base-class interface, as ParseOpData expects.
  ErrorReporter* reporter = &mock_reporter;
  MockDataAllocator mock_allocator;

  // Pack an Operator (including its Conv2D options) into flatbuffer format.
  flatbuffers::FlatBufferBuilder builder;
  // Options are built with the inline helper generated in schema_generated.h.
  // Per the assertions below, the numeric arguments are stride_width=1,
  // stride_height=2, dilation_width_factor=3, dilation_height_factor=4.
  flatbuffers::Offset<void> conv_options =
      CreateConv2DOptions(builder, Padding_SAME, 1, 2,
                          ActivationFunctionType_RELU, 3, 4)
          .Union();
  flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect(
      builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options,
      nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr);
  builder.Finish(conv_offset);
  void* conv_pointer = builder.GetBufferPointer();
  const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);

  // Parse the op options out of the flatbuffer. The params are stored in memory
  // obtained from mock_allocator and returned through output_data.
  void* output_data = nullptr;
  // ASSERT (fatal) rather than EXPECT: the checks below dereference
  // output_data, so a failed parse or null result must end the test instead of
  // crashing it.
  ASSERT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
                                   &mock_allocator, &output_data));
  ASSERT_NE(nullptr, output_data);
  const TfLiteConvParams* params =
      static_cast<const TfLiteConvParams*>(output_data);
  EXPECT_EQ(kTfLitePaddingSame, params->padding);
  EXPECT_EQ(1, params->stride_width);
  EXPECT_EQ(2, params->stride_height);
  EXPECT_EQ(kTfLiteActRelu, params->activation);
  EXPECT_EQ(3, params->dilation_width_factor);
  EXPECT_EQ(4, params->dilation_height_factor);
}
验证通过BuiltinOperator或custom_name从OpResolver中找到对应的TfLiteRegistration
定义mock的TfLiteRegistration回调函数以及MockOpResolver
void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
// Do nothing.
return nullptr;
}
void MockFree(TfLiteContext* context, void* buffer) {
// Do nothing.
}
// Mock kernel prepare: accepts any node without inspecting it.
TfLiteStatus MockPrepare(TfLiteContext* /*context*/, TfLiteNode* /*node*/) {
  const TfLiteStatus status = kTfLiteOk;
  return status;
}
// Mock kernel invoke: performs no computation and always reports success.
TfLiteStatus MockInvoke(TfLiteContext* /*context*/, TfLiteNode* /*node*/) {
  const TfLiteStatus status = kTfLiteOk;
  return status;
}
// OpResolver test double that knows exactly two ops:
//   - the builtin BuiltinOperator_CONV_2D;
//   - the custom op named "mock_custom".
// Both resolve to a registration built from MockInit, MockFree, MockPrepare
// and MockInvoke. Every other lookup returns nullptr. The version argument is
// ignored.
class MockOpResolver : public OpResolver {
 public:
  const TfLiteRegistration* FindOp(BuiltinOperator op,
                                   int /*version*/) const override {
    return op == BuiltinOperator_CONV_2D ? MockRegistration() : nullptr;
  }

  const TfLiteRegistration* FindOp(const char* op,
                                   int /*version*/) const override {
    // strcmp on a null pointer is undefined behavior; treat a null name as
    // "not found".
    if (op == nullptr) return nullptr;
    return strcmp(op, "mock_custom") == 0 ? MockRegistration() : nullptr;
  }

 private:
  // Single function-local static registration shared by both lookups; it
  // lives for the rest of the process, so returning its address is safe.
  static const TfLiteRegistration* MockRegistration() {
    static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
                                   MockInvoke};
    return &r;
  }
};
// Resolves ops through the OpResolver interface, both by builtin code and by
// custom name. Checks that the mock callbacks are returned and that unknown
// ops yield nullptr.
TEST(OpResolver, TestResolver) {
  MockOpResolver mock_resolver;
  OpResolver* resolver = &mock_resolver;

  // Builtin lookup. ASSERT (fatal) because the registration is dereferenced
  // right after.
  const TfLiteRegistration* registration =
      resolver->FindOp(BuiltinOperator_CONV_2D, 0);
  ASSERT_NE(nullptr, registration);
  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));

  // Custom-name lookup.
  registration = resolver->FindOp("mock_custom", 0);
  ASSERT_NE(nullptr, registration);
  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));

  // Ops the mock does not know must not resolve.
  EXPECT_EQ(nullptr, resolver->FindOp(BuiltinOperator_ADD, 0));
  EXPECT_EQ(nullptr, resolver->FindOp("unknown_custom", 0));
}
验证从flatbuffer格式的数据中解析出op code,并通过OpResolver由该op code查找得到TfLiteRegistration
// Serializes an OperatorCode for CONV_2D into a flatbuffer and resolves it with
// GetRegistrationFromOpCode. Checks that the mock callbacks come back and that
// nothing was reported to the error reporter.
TEST(OpResolver, TestGetRegistrationFromOpCodeConv) {
  MockOpResolver mock_resolver;
  OpResolver* resolver = &mock_resolver;
  MockErrorReporter mock_reporter;
  ErrorReporter* reporter = &mock_reporter;

  flatbuffers::FlatBufferBuilder builder;
  flatbuffers::Offset<OperatorCode> conv_offset =
      CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0);
  builder.Finish(conv_offset);
  void* conv_pointer = builder.GetBufferPointer();
  const OperatorCode* conv_code =
      flatbuffers::GetRoot<OperatorCode>(conv_pointer);

  const TfLiteRegistration* registration = nullptr;
  // ASSERT (fatal) because `registration` is dereferenced below.
  ASSERT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
                                                 &registration));
  ASSERT_NE(nullptr, registration);
  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
  // A successful lookup must not have reported anything.
  EXPECT_EQ(0, mock_reporter.GetBufferSize());
}