【踩坑随笔】TensorFlowLite_ESP32库中不包含REDUCE_PROD算子,手动移植

我是用vscode+platformIO开发,用的platformIO里导入的库,算子里没有REDUCE_PROD,这个算子是在我把模型转换成tflite格式后产生的,没有办法直接在模型上解决,所以决定自己引入这个算子的操作

1. 算子定义和注册

tensorflow/lite/micro/micro_mutable_op_resolver.h(在libdeps目录下找,在platform.ini里导入,编译之后才会在.pio目录下生成这个目录),直接复制AddReduceMax()的粘贴在它后面修改成AddReduceProd()


  TfLiteStatus AddReduceProd() {
    return AddBuiltin(BuiltinOperator_REDUCE_PROD,
                      tflite::ops::micro::Register_REDUCE_PROD(), ParseReducer);
  }

2. REDUCE_PROD算子实现

实现的代码是AI写的我测了一下暂时没问题,建议自己对着TFlite的算子实现,我这里是为了临时解bug记的笔记并不是稳定版本
AddReduceMax()右键跳转到定义,在 namespace reduce 内,添加新的 EvalProd 实现

TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);
  OpData* op_data = static_cast<OpData*>(node->user_data);

  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  int* temp_buffer = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->temp_buffer_idx));
  int* resolved_axis = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->resolved_axis_idx));

  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<float>(
              tflite::micro::GetTensorData<float>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<float>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              1.0f,  // identity for product
              [](const float current, const float in) -> float {
                return current * in;
              }));
      break;
    case kTfLiteInt8:
      // 简化量化版本:仅在 scale 和 zero point 相等的前提下做整数乘积(有溢出/范围问题)。
      TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
                        static_cast<double>(op_data->output_scale));
      TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<int8_t>(
              tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              static_cast<int8_t>(1),  // identity
              [](const int8_t current, const int8_t in) -> int8_t {
                // 注意:直接乘可能 overflow,且量化语义不精确。
                int32_t prod = static_cast<int32_t>(current) *
                               static_cast<int32_t>(in);
                // 简单饱和到 int8_t
                if (prod > std::numeric_limits<int8_t>::max()) {
                  prod = std::numeric_limits<int8_t>::max();
                } else if (prod < std::numeric_limits<int8_t>::lowest()) {
                  prod = std::numeric_limits<int8_t>::lowest();
                }
                return static_cast<int8_t>(prod);
              }));
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Only float32 and int8 types are supported for PROD.\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

Register_REDUCE_MAX()后添加注册函数

TfLiteRegistration Register_REDUCE_PROD() {
  return tflite::micro::RegisterOp(reduce::InitReduce, reduce::PrepareMax,
                                   reduce::EvalProd);
}

micro_ops.hTfLiteRegistration Register_REDUCE_MAX();后加上

TfLiteRegistration Register_REDUCE_PROD();

最后在all_ops_resolver.cpp里加上,加在 AddReduceMax();后一行方便快速定位

  AddReduceProd();

3. 编译选项

最后我们把整个TensorFlowLite_ESP32包保存备用,后续可以直接在放在lib目录下不会被编译覆盖掉
然后修改platformIO.ini,把lib_deps引入的原先的包注释掉,build_flags里添加

build_flags = 
	-I lib/TensorFlowLite_ESP32
2022 / 01/ 30: 新版esptool 刷micropython固件指令是 esptool.py cmd... 而是 esptool cmd... 即可;另外rshell 在 >= python 3.10 的时候出错解决方法可以查看:  已于2022年发布的: 第二章:修复rshell在python3.10出错 免费内容: https://edu.csdn.net/course/detail/29666 2025/07/07: 由于该视频在2019年制作,当时py3.7;现在py3.13 由于pyreadline冲突rshell已能用;如果仍要使用rshell请安装py3.12并用我修改的rshell: https://github.com/gamefunc/rshell/releases micropython语法和python3一样,编写起来非常方便。如果你快速入门单片机物联网而且像轻松实现各种功能,那绝力推荐使用micropython。方便易懂易学。 同时如果你懂C语音,也可以用C写好函数并编译进micropython固件里然后进入micropython调用(非必须)。 能通过WIFI联网(2.1章),也能通过sim卡使用2G/3G/4G/5G联网(4.5章)。 为实现语音控制,本教程会教大家使用tensorflow利用神经网络训练自己的语音模型并应用。为实现通过网页控制,本教程会教大家linux(debian10 nginx->uwsgi->python3->postgresql)网站前后台入门。为记录单片机传输过来的数据, 本教程会教大家入门数据。  本教程会通过通俗易懂的比喻来讲解各种原理与思路,并手把手编写程序来实现各项功能。 本教程micropython版本是 2019年6月发布的1.11; 更多内容请看视频列表。  学习这门课程之前你需要至少掌握: 1: python3基础(变量, 循环, 函数, 常用, 常用方法)。 本视频使用到的零件与淘宝上大致价格:     1: 超声波传感器(3)     2: MAX9814麦克风放大模块(8)     3: DHT22(15)     4: LED(0.1)     5: 8路5V低电平触发继电器(12)     6: HX1838红外接收模块(2)     7:红外发射管(0.1),HX1838红外接收板(1)     other: 电表, 排线, 面包板(2)*2,ESP32(28)  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

RIKI_1

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值