TfLite: TensorFlow Model Formats and Post-training Quantization

TensorFlow Model Formats

TensorFlow has several model formats, each suited to a different scenario. Any model that conforms to the corresponding spec can be deployed easily to an online service or to mobile devices. They are briefly listed here.

  • Checkpoint: stores the model weights; mainly used to back up parameters during training and to warm-start training.
  • GraphDef: stores the model graph without the weights; together with a checkpoint it contains everything needed to serve the model.
  • SavedModel: a model exported through the saved_model API. It contains both the graph and the weights and can be served directly; this is the recommended format for TensorFlow and Keras models.
  • FrozenGraph: produced by freeze_graph.py, which merges and optimizes a checkpoint and a GraphDef; it can be deployed directly to mobile devices such as Android and iOS.
  • TFLite: a flatbuffer-based, optimized model format that can be deployed directly to mobile devices such as Android and iOS; its API differs somewhat from FrozenGraph's.

These are the main TensorFlow model formats; they are produced by different tools and serve different purposes. The low-level TensorFlow API and Keras are used differently, but the formats themselves are independent of whether the model was built with Keras. Note that SavedModel and FrozenGraph are two distinct formats.
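As a rough illustration of the Checkpoint, SavedModel and TFLite formats, here is a minimal sketch using the TF 2.x APIs (the model definition and all paths below are placeholders, not taken from the original post):

import tensorflow as tf

# Placeholder model: any tf.keras model works the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(784,))])

# Checkpoint: weights only, used for training backup / warm start.
ckpt = tf.train.Checkpoint(model=model)
ckpt.save('ckpt/model')

# SavedModel: graph + weights, directly servable (e.g. by TensorFlow Serving).
tf.saved_model.save(model, 'export/saved_model')

# TFLite: flatbuffer-based format for Android / iOS / embedded deployment.
converter = tf.lite.TFLiteConverter.from_saved_model('export/saved_model')
tflite_model = converter.convert()
with open('export/model.tflite', 'wb') as f:
    f.write(tflite_model)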

 

 Saving and Serializing Models with TensorFlow Keras

https://www.tensorflow.org/beta/guide/keras/saving_and_serializing

 Whole-model saving

# Save the model
model.save('path_to_my_model.h5')

# Recreate the exact same model purely from the file
new_model = keras.models.load_model('path_to_my_model.h5')

Save the model in HDF5 (.h5) format, then recreate it from that file.

Export to SavedModel

You can also export a whole model to the TensorFlow SavedModel format. SavedModel is a standalone serialization format for TensorFlow objects, supported by TensorFlow Serving as well as TensorFlow implementations other than Python.

# Export the model to a SavedModel
keras.experimental.export_saved_model(model, 'path_to_saved_model')

# Recreate the exact same model
new_model = keras.experimental.load_from_saved_model('path_to_saved_model')

# Check that the state is preserved
new_predictions = new_model.predict(x_test)

 Architecture-only saving

 

config = model.get_config()

You can alternatively use to_json() and from_json(), which use a JSON string to store the config instead of a Python dict. This is useful for saving the config to disk.

json_config = model.to_json()
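The inverse step is a one-liner in either case; a sketch using the standard Keras loaders (assuming `model` is a Functional-API Keras model, as in the guide):

# Recreate a freshly initialized model from the Python-dict config.
reinitialized_model = keras.Model.from_config(config)

# Or recreate it from the JSON string (useful after reading it back from disk).
reinitialized_model = keras.models.model_from_json(json_config)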

 Weights-only saving

weights = model.get_weights()  # Retrieves the state of the model.
model.set_weights(weights)  # Sets the state of the model.
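The weights can also be written to and read from disk directly; a minimal sketch (the path is a placeholder):

# TensorFlow checkpoint format by default; use a '.h5' suffix for HDF5.
model.save_weights('path_to_my_weights')

# Load them back into a model with the same architecture.
model.load_weights('path_to_my_weights')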

 A Brief History of Model Optimization

Initial release (weight-only quantization)

https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3

Introducing the Model Optimization Toolkit for TensorFlow

This is post-training quantization via "hybrid operations", i.e. mixed-data-type computation in which only the weights are quantized.

In the tensor dump below, the kernel/filter tensor has type kTfLiteUInt8 while the output and other tensors are kTfLiteFloat32, hence "hybrid":

Tensor   1 img                  kTfLiteFloat32  kTfLiteArenaRw       3136 bytes ( 0.0 MB)  1 784
Tensor   2 mnist_model/dense/MatMtranspose kTfLiteUInt8   kTfLiteMmapRo      50176 bytes ( 0.0 MB)  64 784
Tensor   3 mnist_model/dense/MatMul_bias kTfLiteFloat32   kTfLiteMmapRo        256 bytes ( 0.0 MB)  64
Tensor   4 mnist_model/dense/Relu kTfLiteFloat32  kTfLiteArenaRw        256 bytes ( 0.0 MB)  1 64
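Producing such a hybrid (weight-only) model only requires enabling the default optimization on the converter; roughly (a sketch, with `saved_model_dir` as a placeholder):

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# No representative_dataset is given, so only the weights are quantized to
# 8 bits; activations stay float32 and the kernels run in "hybrid" mode.
tflite_hybrid_model = converter.convert()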
 

Current state (quantizes both weights and activations)

https://medium.com/tensorflow/tensorflow-model-optimization-toolkit-post-training-integer-quantization-b4964a1ea9ba

TensorFlow Model Optimization Toolkit — Post-Training Integer Quantization (the post also links to a tutorial)

The difference from the above is that a representative sample dataset, representative_dataset, must be supplied:

def representative_dataset_gen():
  data = tfds.load(...)

  for _ in range(num_calibration_steps):
    image, = data.take(1)
    yield [image]

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = tf.lite.RepresentativeDataset(
    representative_dataset_gen)
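To check which tensors actually end up quantized (producing a dump like the one shown earlier), the Python interpreter can list the tensor details; a sketch continuing from the converter above:

tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
for t in interpreter.get_tensor_details():
    # 'dtype' is e.g. numpy.uint8 for quantized tensors, numpy.float32 otherwise.
    print(t['index'], t['name'], t['dtype'], t['quantization'])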

Problems encountered after quantization

A successfully quantized model file does not guarantee that inference will succeed; it also depends on whether each operator's kernel actually implements support for quantization.

Quantization support in FULLY_CONNECTED

The standard TFLite runtime supports hybrid execution:

  TfLiteRegistration* Register_FULLY_CONNECTED() {
    return Register_FULLY_CONNECTED_PIE();
  }

  TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
    static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
                                   fully_connected::Prepare,
                                   fully_connected::Eval<fully_connected::kPie>};
    return &r; 
  }

  template <KernelType kernel_type>
  TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    auto* params =
        reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
    OpData* data = reinterpret_cast<OpData*>(node->user_data);
  
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
    const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  
    switch (filter->type) {  // Already know in/out types are same.
      case kTfLiteUInt8:
        if (params->weights_format ==
            kTfLiteFullyConnectedWeightsFormatDefault) {
          printf("EvalQuantized<kernel_type>\n");
          return EvalQuantized<kernel_type>(context, node, params, data, input,
                                            filter, bias, output);
        }
      // ... other filter types omitted ...
    }
  }

  template <KernelType kernel_type>
  TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                             TfLiteFullyConnectedParams* params, OpData* data,
                             const TfLiteTensor* input,
                             const TfLiteTensor* filter, const TfLiteTensor* bias,
                             TfLiteTensor* output) {
    gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
  
    int32_t input_offset = -input->params.zero_point;
    int32_t filter_offset = -filter->params.zero_point;
    int32_t output_offset = output->params.zero_point;

    if (kernel_type == kPie && input->type == kTfLiteFloat32) {
      printf("kPie:\n");
      // Pie currently only supports quantized models and float inputs/outputs.
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
      return EvalHybrid(context, node, params, data, input, filter, bias,
                        input_quantized, scaling_factors, output);
    }
  }

TFLite for Microcontrollers, however, does not support hybrid execution, so inference ends with an error:

  TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    auto* params =
        reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  
    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
    const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  
    TfLiteType data_type = input->type;
    OpData local_data_object;
    OpData* data = &local_data_object;
    TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
                                          filter, bias, output, data));
  
    switch (filter->type) {  // Already know in/out types are same.
      case kTfLiteFloat32:
        return EvalFloat(context, node, params, data, input, filter, bias,
                         output);
      case kTfLiteUInt8:
        return EvalQuantized(context, node, params, data, input, filter, bias,
                             output);
      // ... other filter types omitted ...
    }
  }

  TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                             TfLiteFullyConnectedParams* params, OpData* data,
                             const TfLiteTensor* input,
                             const TfLiteTensor* filter, const TfLiteTensor* bias,
                             TfLiteTensor* output) {
    const int32_t input_offset = -input->params.zero_point;
    const int32_t filter_offset = -filter->params.zero_point;
    const int32_t output_offset = output->params.zero_point;
  
    tflite::FullyConnectedParams op_params;
    op_params.input_offset = input_offset;                                                                      
    op_params.weights_offset = filter_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = data->output_multiplier;
    // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
    op_params.output_shift = -data->output_shift;
    op_params.quantized_activation_min = data->output_activation_min;
    op_params.quantized_activation_max = data->output_activation_max;
  
  #define TF_LITE_FULLY_CONNECTED(output_data_type)                      \
    reference_ops::FullyConnected(                                       \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
        GetTensorShape(output), GetTensorData<output_data_type>(output), \
        nullptr)
    switch (output->type) {  // output->type is float32 here, so we fall into default
      case kTfLiteUInt8:
        TF_LITE_FULLY_CONNECTED(uint8_t);
        break;
      case kTfLiteInt16:
        TF_LITE_FULLY_CONNECTED(int16_t);
        break;
      default:
        printf("output type: %d\n", output->type);
        context->ReportError(
            context,
            "Quantized FullyConnected expects output data type uint8 or int16");
        return kTfLiteError;
    }
    return kTfLiteOk;
  }
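One way to avoid this on microcontrollers (a sketch of converter settings, not from the original post) is to request full integer quantization so that no hybrid kernels are needed at all:

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
# Fail at conversion time instead of silently keeping float/hybrid ops.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8    # newer TF versions prefer tf.int8
converter.inference_output_type = tf.uint8
tflite_int_model = converter.convert()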
 
