【Paddle2ONNX】为Paddle2ONNX修复elementwise_floordiv算子计算错误的问题

文章讨论了PaddlePaddle框架中elementwise_floordiv算子在转换为ONNX时的问题,原代码将其视为普通除法而非整除,导致测试失败。作者提供了修改后的代码,确保在int类型下进行整除并添加了地板操作。
摘要由CSDN通过智能技术生成

简介

elementwise_floordiv 算子在int32/int64的情况下直接转换成了ONNX中的div算子,由于div算子是普通除操作,而不是整除操作,因此无法通过CI的校验。

实现过程

原核心实现代码如下

void ElementWiseFloordivMapper::Opset7() {
    auto input_x_info = GetInput("X");
    auto input_y_info = GetInput("Y");
    auto output_info = GetOutput("Out");

    bool is_int = false;
    if (input_x_info[0].dtype <= 3 || input_x_info[0].dtype == 20 ||
        input_y_info[0].dtype <= 3 || input_y_info[0].dtype == 20) {
        is_int = true;
    }
    if (axis_ == -1 || axis_ == input_x_info[0].Rank() - 1 ||
        input_x_info[0].Rank() == input_y_info[0].Rank()) {
        if (is_int) {
            helper_->MakeNode("Div", {input_x_info[0].name, input_y_info[0].name},
                {output_info[0].name});
        } else {
            auto div_node = helper_->MakeNode(
            "Div", {input_x_info[0].name, input_y_info[0].name});
            helper_->MakeNode("Floor", {div_node->output(0)}, {output_info[0].name});
        }
    } else {
        std::vector<int64_t> broadcast_shape;
        broadcast_shape.resize(axis_ + input_x_info[0].Rank(), 1);
        for (auto i = 0; i < input_y_info[0].Rank(); ++i) {
            broadcast_shape[axis_ + i] = input_y_info[0].shape[i];
        }
        std::string broadcast_shape_node =
        helper_->Constant(GetOnnxDtype(P2ODataType::INT64), broadcast_shape);
        auto y_node = helper_->MakeNode(
        "Reshape", {input_y_info[0].name, broadcast_shape_node});
        if (is_int) {
            helper_->MakeNode("Div", {input_x_info[0].name, y_node->output(0)},
                {output_info[0].name});
        } else {
            auto div_node =
            helper_->MakeNode("Div", {input_x_info[0].name, y_node->output(0)});
            helper_->MakeNode("Floor", {div_node->output(0)}, {output_info[0].name});
        }
    }
}

可以看到,针对int的情况,原转换函数直接将elementwise_floordiv 算子转换成了Div算子,这显然缺少了一个floor操作,因此修改为如下代码:

void ElementWiseFloordivMapper::Opset7() {
  auto input_x_info = GetInput("X");
  auto input_y_info = GetInput("Y");
  auto output_info = GetOutput("Out");

  auto div_input_0 = helper_->AutoCast(input_x_info[0].name, input_x_info[0].dtype, P2ODataType::FP32);
  auto div_input_1 = helper_->AutoCast(input_y_info[0].name, input_y_info[0].dtype, P2ODataType::FP32);

 if (axis_ == -1 || axis_ == input_x_info[0].Rank() - 1 || input_x_info[0].Rank() == input_y_info[0].Rank()) {
    auto div_node = helper_->MakeNode("Div", {div_input_0, div_input_1});
    auto floor_output = helper_->MakeNode("Floor", {div_node->output(0)});
    helper_->AutoCast(floor_output->output(0), output_info[0].name, P2ODataType::FP32, output_info[0].dtype);
  } else {
    std::vector<int64_t> broadcast_shape;
    broadcast_shape.resize(axis_ + input_x_info[0].Rank(), 1);
    for (auto i = 0; i < input_y_info[0].Rank(); ++i) {
      broadcast_shape[axis_ + i] = input_y_info[0].shape[i];
    }
    std::string broadcast_shape_node = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), broadcast_shape);
    auto y_node = helper_->MakeNode("Reshape", {div_input_1, broadcast_shape_node});
    auto div_node = helper_->MakeNode("Div", {div_input_0, y_node->output(0)});
    auto floor_output = helper_->MakeNode("Floor", {div_node->output(0)});
    helper_->AutoCast(floor_output->output(0), output_info[0].name, P2ODataType::FP32, output_info[0].dtype);
  }
}

参考资料

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值