TFLite: 代码结构(C and Core)

C

 c_api_internal.h

// This file defines a C API for implementing operations in tflite.
// These operations can be defined using C++ but the interface between
// the interpreter and the operations is C.
//
// Summary of abstractions
// TF_LITE_ENSURE - Self-sufficient error checking
// TfLiteStatus - Status reporting
// TfLiteIntArray - stores tensor shapes (dims),
// TfLiteContext - allows an op to access the tensors
// TfLiteTensor - tensor (a multidimensional array)
// TfLiteNode - a single node or operation
// TfLiteRegistration - the implementation of a conceptual operation.

builtin_op_data.h

各种 builtin operator 的参数结构体定义(例如 TfLiteConvParams)。

 

core

定义了 3 个抽象接口:ErrorReporter、BuiltinDataAllocator、OpResolver。

 

下面的测试演示:怎样使用 ErrorReporter;怎样从 Operator 中得到 op 参数;怎样分配 BuiltinData 内存来保存解析得到的 op 参数。

  class MockErrorReporter : public ErrorReporter {
   public:
    MockErrorReporter() : buffer_size_(0) {}
    int Report(const char* format, va_list args) override {
      buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);//把log保存到分配的buffer中
      return buffer_size_;
    }
    char* GetBuffer() { return buffer_; }
    int GetBufferSize() { return buffer_size_; }
  
   private:
    static constexpr int kBufferSize = 256;
    char buffer_[kBufferSize];
    int buffer_size_;
  };
  
  // Used to determine how the op data parsing function creates its working space.
  class MockDataAllocator : public BuiltinDataAllocator {
   public:
    MockDataAllocator() : is_allocated_(false) {}
    void* Allocate(size_t size) override {
      EXPECT_FALSE(is_allocated_);
      const int max_size = kBufferSize;
      EXPECT_LE(size, max_size);
      is_allocated_ = true;
      return buffer_;
    }
  void Deallocate(void* data) override { is_allocated_ = false; }
   private:
    static constexpr int kBufferSize = 1024;
    char buffer_[kBufferSize];
    bool is_allocated_;
  };

上面定义了两个 mock 类:MockErrorReporter 和 MockDataAllocator。

  TEST(FlatbufferConversions, TestParseOpDataConv) {
    MockErrorReporter mock_reporter;
    ErrorReporter* reporter = &mock_reporter; //基类指针指向派生类
    MockDataAllocator mock_allocator;
    

   //把Operator打包成flatbuffer协议的格式:其中包含 operator参数
    flatbuffers::FlatBufferBuilder builder;
    flatbuffers::Offset<void> conv_options = //使用schema_generated.h中online函数生成conv参数
        CreateConv2DOptions(builder, Padding_SAME, 1, 2,
                            ActivationFunctionType_RELU, 3, 4).Union();
    flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect(
        builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options,
        nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr);
    builder.Finish(conv_offset);
    void* conv_pointer = builder.GetBufferPointer();
    const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);

   //从flatbuffer格式中解析出op 参数并保存到mock_allocator,分配的memory: output_data里
    void* output_data = nullptr;
    EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
                                     &mock_allocator, &output_data));
    EXPECT_NE(nullptr, output_data);
    TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
    EXPECT_EQ(kTfLitePaddingSame, params->padding);
    EXPECT_EQ(1, params->stride_width);
    EXPECT_EQ(2, params->stride_height);
    EXPECT_EQ(kTfLiteActRelu, params->activation);
    EXPECT_EQ(3, params->dilation_width_factor);
    EXPECT_EQ(4, params->dilation_height_factor);
  }

验证能否通过 BuiltinOperator 或 custom op 名称,从 OpResolver 中找到对应的 TfLiteRegistration。

先定义构造 TfLiteRegistration 所需的 mock 回调函数(init/free/prepare/invoke),以及一个 MockOpResolver。

// Stub for TfLiteRegistration::init: keeps no per-node state, so the result
// is always nullptr.
void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
  static_cast<void>(context);
  static_cast<void>(buffer);
  static_cast<void>(length);
  void* const no_state = nullptr;
  return no_state;
}
// Stub for TfLiteRegistration::free: there is nothing to release.
void MockFree(TfLiteContext* context, void* buffer) {
  static_cast<void>(context);
  static_cast<void>(buffer);
}
// Stub for TfLiteRegistration::prepare: accepts every node unconditionally.
TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
  static_cast<void>(context);
  static_cast<void>(node);
  const TfLiteStatus status = kTfLiteOk;
  return status;
}
// Stub for TfLiteRegistration::invoke: performs no work and reports success.
TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
  static_cast<void>(context);
  static_cast<void>(node);
  const TfLiteStatus status = kTfLiteOk;
  return status;
}

  class MockOpResolver : public OpResolver {
   public:
    const TfLiteRegistration* FindOp(BuiltinOperator op,
                                   int version) const override {
      if (op == BuiltinOperator_CONV_2D) {
        static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
                                     MockInvoke};
        return &r;
      } else {
        return nullptr;
      }
    }
 const TfLiteRegistration* FindOp(const char* op, int version) const override {
      if (strcmp(op, "mock_custom") == 0) {
        static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
                                     MockInvoke};
        return &r;
      } else {
        return nullptr;
      }
    }                                                                                                                                                                                                
  };

  TEST(OpResolver, TestResolver) {
    MockOpResolver mock_resolver;
    OpResolver* resolver = &mock_resolver;
  
    const TfLiteRegistration* registration =
        resolver->FindOp(BuiltinOperator_CONV_2D, 0);                                                                                                                                                       
    EXPECT_NE(nullptr, registration);
    EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
    EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
    EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));

}

 验证从 flatbuffer 格式的数据中解析出 op code,并通过 OpResolver 根据该 op code 查找得到 TfLiteRegistration。

  TEST(OpResolver, TestGetRegistrationFromOpCodeConv) {
    MockOpResolver mock_resolver;
    OpResolver* resolver = &mock_resolver;
    MockErrorReporter mock_reporter;
    ErrorReporter* reporter = &mock_reporter;
                                                                                                                                                                                                            
    flatbuffers::FlatBufferBuilder builder;
    flatbuffers::Offset<OperatorCode> conv_offset =
        CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0);
    builder.Finish(conv_offset);


    void* conv_pointer = builder.GetBufferPointer();
    const OperatorCode* conv_code =
        flatbuffers::GetRoot<OperatorCode>(conv_pointer);
    const TfLiteRegistration* registration = nullptr;
    EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
                                                   &registration));

    EXPECT_NE(nullptr, registration);
    EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
    EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
    EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
    EXPECT_EQ(0, mock_reporter.GetBufferSize());
  }

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值