I develop with VSCode + PlatformIO, using the library pulled in through PlatformIO. Its op resolver has no REDUCE_PROD; this operator appeared after I converted my model to the tflite format, and there was no way to fix it on the model side, so I decided to add the operator myself.
1. Operator definition and registration
In tensorflow/lite/micro/micro_mutable_op_resolver.h (look under the libdeps directory; the library is declared in platformio.ini, and that directory only appears under .pio after a first build), copy AddReduceMax(), paste it right after, and rename it to AddReduceProd():
TfLiteStatus AddReduceProd() {
  return AddBuiltin(BuiltinOperator_REDUCE_PROD,
                    tflite::ops::micro::Register_REDUCE_PROD(), ParseReducer);
}
2. REDUCE_PROD operator implementation
The implementation below was written by an AI; I tested it and it works so far, but I suggest checking it against TFLite's own operator implementation. These are notes from fixing a bug in a hurry, not a stable version.
Right-click AddReduceMax() and jump to its definition; inside namespace reduce, add the new EvalProd implementation:
TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 1);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TfLiteReducerParams* params =
      static_cast<TfLiteReducerParams*>(node->builtin_data);
  OpData* op_data = static_cast<OpData*>(node->user_data);

  int num_axis = static_cast<int>(ElementCount(*axis->dims));
  int* temp_buffer = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->temp_buffer_idx));
  int* resolved_axis = static_cast<int*>(
      context->GetScratchBuffer(context, op_data->resolved_axis_idx));

  switch (input->type) {
    case kTfLiteFloat32:
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<float>(
              tflite::micro::GetTensorData<float>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<float>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              1.0f,  // identity for product
              [](const float current, const float in) -> float {
                return current * in;
              }));
      break;
    case kTfLiteInt8:
      // Simplified quantized path: multiply raw integers only when input and
      // output share the same scale and zero point (still has overflow/range
      // problems; see the note after the function).
      TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
                        static_cast<double>(op_data->output_scale));
      TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
      TF_LITE_ENSURE(
          context,
          reference_ops::ReduceGeneric<int8_t>(
              tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
              input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
              output->dims->data, output->dims->size,
              tflite::micro::GetTensorData<int>(axis), num_axis,
              params->keep_dims, temp_buffer, resolved_axis,
              static_cast<int8_t>(1),  // identity
              [](const int8_t current, const int8_t in) -> int8_t {
                // Note: a plain multiply can overflow, and the quantized
                // semantics are not exact.
                int32_t prod = static_cast<int32_t>(current) *
                               static_cast<int32_t>(in);
                // Saturate the result back into int8_t range.
                if (prod > std::numeric_limits<int8_t>::max()) {
                  prod = std::numeric_limits<int8_t>::max();
                } else if (prod < std::numeric_limits<int8_t>::lowest()) {
                  prod = std::numeric_limits<int8_t>::lowest();
                }
                return static_cast<int8_t>(prod);
              }));
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context, "Only float32 and int8 types are supported for PROD.\n");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
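To see why the int8 path is only an approximation: under affine quantization r = scale * (q - zp), the true product of two elements is scale² * (q1 - zp) * (q2 - zp), while the lambda above just multiplies raw q values and saturates. A small standalone snippet (not part of the kernel; values are made up) makes the gap concrete:

#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 0.5f;
  const int zp = 0;
  int8_t q1 = 4, q2 = 6;  // dequantized: 2.0 and 3.0
  float true_prod = (scale * (q1 - zp)) * (scale * (q2 - zp));  // 6.0
  int8_t raw = static_cast<int8_t>(q1 * q2);                    // 24
  printf("true=%.1f raw_dequantized=%.1f\n", true_prod, scale * (raw - zp));
  // Prints true=6.0 raw_dequantized=12.0 -- wrong unless scale == 1, zp == 0,
  // which is exactly what the TF_LITE_ENSURE_EQ guards try to approximate.
}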
After Register_REDUCE_MAX(), add the registration function. It reuses reduce::PrepareMax, which already sets up the scratch buffers and OpData fields that EvalProd reads (EvalMax uses the same ones):
TfLiteRegistration Register_REDUCE_PROD() {
  return tflite::micro::RegisterOp(reduce::InitReduce, reduce::PrepareMax,
                                   reduce::EvalProd);
}
In micro_ops.h, right after TfLiteRegistration Register_REDUCE_MAX();, add the declaration:
TfLiteRegistration Register_REDUCE_PROD();
Finally, in all_ops_resolver.cpp add the call below, on the line right after AddReduceMax(); so it's easy to find later:
AddReduceProd();
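With that line in place, any sketch that builds its interpreter from AllOpsResolver picks up REDUCE_PROD with no further changes. A rough sketch of that wiring (header path and constructor signature vary across TFLM versions; model, tensor_arena, kTensorArenaSize, and error_reporter are the usual application-side placeholders):

#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_interpreter.h"

static tflite::AllOpsResolver resolver;  // now also resolves REDUCE_PROD
static tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                            kTensorArenaSize, &error_reporter);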
3. Build options
Finally, save the whole modified TensorFlowLite_ESP32 package for reuse; later it can go straight into the project's lib directory, where a rebuild won't overwrite the changes.
Then edit platformio.ini: comment out the original package in lib_deps and add the include path to build_flags:
build_flags =
    -I lib/TensorFlowLite_ESP32
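Put together, the relevant part of platformio.ini looks roughly like this (board, framework, and the exact lib_deps entry are placeholders for whatever the project originally used):

[env:esp32dev]
platform = espressif32
board = esp32dev            ; placeholder, use your board
framework = arduino
; lib_deps =
;     TensorFlowLite_ESP32  ; original registry package, now vendored in lib/
build_flags =
    -I lib/TensorFlowLite_ESP32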