简介
在PaddlePaddle2.6中,relu6算子在PaddleInference上发生了变化,删除掉了threshold这个Attr,因此我们需要想办法自行适配它。
适配过程
原解析relu6算子的核心代码如下:
void Relu6Mapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
float min = 0.0;
helper_->Clip(input_info[0].name, output_info[0].name, min, threshold_,
input_info[0].dtype);
}
如果仅需要适配PaddlePaddle2.6,只需要改动为(同时还需要在类的构造函数中删除对threshold参数的读取):
void Relu6Mapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
float min = 0.0;
helper_->Clip(input_info[0].name, output_info[0].name, min, 6, input_info[0].dtype);
}
考虑到要兼容PaddlePaddle2.5之前的用户,因此不能直接删除掉threshold这个参数,进一步修改如下:
void Relu6Mapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
float threshold = 6.0;
if (HasAttr("threshold")) {
GetAttr("threshold", &threshold);
}
helper_->Clip(input_info[0].name, output_info[0].name, min, threshold, input_info[0].dtype);
}