【Paddle2ONNX】为Paddle2ONNX适配roll算子

1 简介

Roll算子一般被用再Swin结构中,Paddle2ONNX暂时不支持该算子,本教程介绍如何为Paddle2ONNX添加roll算子。

2 实现过程

2.1 Roll算子简介

paddle.roll(x, shifts, axis=None, name=None)

  • x (Tensor)– 输入的 Tensor。
  • shifts (int|list|tuple) - 滚动位移。如果 shifts 是一个元组或者列表,则 axis 必须是相同大小的元组或者列表,输入 Tensor 将依次沿着每个维度滚动相应的数值。
  • axis (int|list|tuple,可选) – 滚动轴。默认值为 None。
  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

沿着指定维度 axis 对输入 x 进行循环滚动,当元素移动到最后位置时,会从第一个位置重新插入。如果 axis 为 None,则输入在被循环滚动之前,会先展平成 1-D Tensor,滚动操作完成后恢复成原来的形状。

2.2 在Paddle2ONNX中实现roll算子

  1. 首先在paddle2onnx/mapper/tensor下新建roll.hroll.cpp,并在roll.h中添加对Roll算子的定义
#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {
class RollMapper : public Mapper {
    public:
    RollMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
    int64_t op_id)
    : Mapper(p, helper, block_id, op_id) {}
    void Opset7();
};
}  // namespace paddle2onnx
  1. 注册Roll算子
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
REGISTER_MAPPER(roll, RollMapper)
}
  1. 添加对axis为None情况下的Roll算子
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");

  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);
  std::vector<int64_t> axis;
  GetAttr("axis", &axis);

  std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
  auto result_name = input_info[0].name;
  if (axis.empty())
  {
    int64_t axes = 0;
    result_name = helper_->Flatten(result_name);
    for(int i = 0;i < shifts.size();i++) {
      auto shift = shifts[i];
      auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
      auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
      temp_node = helper_->MakeNode("Concat", {result_0, result_1});
      AddAttribute(temp_node, "axis", axes);
      result_name = temp_node->output(0);
    }
    helper_->Reshape(result_name, output_info[0].name, input_info[0].shape);
    // helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name});
  }
}
}
  1. 添加对axis不为None的实现
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");

  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);
  std::vector<int64_t> axis;
  GetAttr("axis", &axis);

  std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
  auto result_name = input_info[0].name;
  if (axis.empty())
  {
  } else {
    for(int i = 0;i < shifts.size();i++) {
      auto shift = shifts[i];
      int64_t axes = axis[i];
      auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
      auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
      if(i+1 == shifts.size()) {
        temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name});
      } else {
        temp_node = helper_->MakeNode("Concat", {result_0, result_1});
      }
      AddAttribute(temp_node, "axis", axes);
      result_name = temp_node->output(0);
    }
  }
}
}
  1. 合并后核心代码如下:
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
REGISTER_MAPPER(roll, RollMapper)

void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");

  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);

  std::vector<int64_t> axis;
  GetAttr("axis", &axis);

  std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
  auto result_name = input_info[0].name;
  if (axis.empty())
  {
    int64_t axes = 0;
    result_name = helper_->Flatten(result_name);
    for(int i = 0;i < shifts.size();i++) {
      auto shift = shifts[i];
      auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
      auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
      temp_node = helper_->MakeNode("Concat", {result_0, result_1});
      AddAttribute(temp_node, "axis", axes);
      result_name = temp_node->output(0);
    }
    helper_->Reshape(result_name, output_info[0].name, input_info[0].shape);
    // helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name});
  } else {
    for(int i = 0;i < shifts.size();i++) {
      auto shift = shifts[i];
      int64_t axes = axis[i];
      auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
      auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
      if(i+1 == shifts.size()) {
        temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name});
      } else {
        temp_node = helper_->MakeNode("Concat", {result_0, result_1});
      }
      AddAttribute(temp_node, "axis", axes);
      result_name = temp_node->output(0);
    }
  }
}
}  // namespace paddle2onnx

3 参考资料

  • 9
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值