1 简介
Roll算子常用于Swin Transformer结构中,而Paddle2ONNX暂时不支持该算子。本教程介绍如何为Paddle2ONNX添加roll算子的转换支持。
2 实现过程
2.1 Roll算子简介
paddle.roll(x, shifts, axis=None, name=None)
- x (Tensor) – 输入的 Tensor。
- shifts (int|list|tuple) – 滚动位移。如果 shifts 是一个元组或者列表,则 axis 必须是相同大小的元组或者列表,输入 Tensor 将依次沿着每个维度滚动相应的数值。
- axis (int|list|tuple,可选) – 滚动轴。默认值为 None。
- name (str,可选) – 具体用法请参见 Name,一般无需设置,默认值为 None。
沿着指定维度 axis 对输入 x 进行循环滚动,当元素被移出最后一个位置时,会从第一个位置重新进入。滚动位移可以大于该维度的长度,此时等价于按该维度长度取模后的位移(例如在长度为 5 的维度上滚动 7 等同于滚动 2)。如果 axis 为 None,则输入在被循环滚动之前,会先展平成 1-D Tensor,滚动操作完成后恢复成原来的形状。
2.2 在Paddle2ONNX中实现roll算子
- 首先在paddle2onnx/mapper/tensor目录下新建roll.h和roll.cpp,并在roll.h中声明Roll算子对应的RollMapper类:
#pragma once
#include <string>
#include <vector>
#include "paddle2onnx/mapper/mapper.h"
namespace paddle2onnx {

// Mapper that converts the Paddle `roll` operator into ONNX nodes
// (the implementation in roll.cpp builds it from Slice + Concat).
class RollMapper : public Mapper {
 public:
  RollMapper(const PaddleParser& parser, OnnxHelper* onnx_helper,
             int64_t block_id, int64_t op_id)
      : Mapper(parser, onnx_helper, block_id, op_id) {}

  // Builds the ONNX subgraph for this op. The name follows the minimum ONNX
  // opset it targets (assumed Mapper convention -- confirm in mapper.h).
  void Opset7();
};

}  // namespace paddle2onnx
- 在roll.cpp中通过REGISTER_MAPPER宏注册Roll算子:
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"
namespace paddle2onnx {

// Hook RollMapper up to the Paddle `roll` op via the REGISTER_MAPPER macro.
REGISTER_MAPPER(roll, RollMapper)

}  // namespace paddle2onnx
- 实现axis为None时的转换:先将输入展平为1-D Tensor,沿第0维用两个Slice分别截取尾部和头部,再用Concat按“尾部在前”的顺序拼接,最后Reshape回原始形状:
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"
namespace paddle2onnx {

namespace {
// Maps a roll offset into [0, dim_size) when the rolled extent is known.
//
// Paddle lets |shift| exceed the size of the rolled dimension (the roll wraps
// around), but ONNX Slice clamps out-of-range starts/ends: without this step a
// shift of 7 on a length-5 axis would leave the tensor unchanged instead of
// rolling it by 2. A non-positive dim_size means "unknown/dynamic"; the shift
// is then returned as-is and is only correct if |shift| <= the runtime size.
int64_t NormalizeShift(int64_t shift, int64_t dim_size) {
  if (dim_size <= 0) {
    return shift;
  }
  const int64_t remainder = shift % dim_size;
  return remainder < 0 ? remainder + dim_size : remainder;
}
}  // namespace

// Converts paddle.roll with axis=None: the input is flattened, rolled along
// dimension 0 of the flattened tensor, then reshaped back to the input shape.
void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");
  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);
  std::vector<int64_t> axis;
  GetAttr("axis", &axis);
  auto result_name = input_info[0].name;
  if (axis.empty()) {
    // Element count of the flattened tensor, or -1 if any dim is dynamic.
    int64_t numel = 1;
    for (const auto dim : input_info[0].shape) {
      if (dim < 0) {
        numel = -1;
        break;
      }
      numel *= dim;
    }
    const int64_t roll_axis = 0;
    result_name = helper_->Flatten(result_name);
    for (const auto raw_shift : shifts) {
      const int64_t shift = NormalizeShift(raw_shift, numel);
      // Roll by `shift`: the trailing `shift` elements move to the front.
      auto tail = helper_->Slice(result_name, {roll_axis}, {-shift},
                                 {(std::numeric_limits<int64_t>::max)()});
      auto head = helper_->Slice(result_name, {roll_axis}, {0}, {-shift});
      auto concat_node = helper_->MakeNode("Concat", {tail, head});
      AddAttribute(concat_node, "axis", roll_axis);
      result_name = concat_node->output(0);
    }
    // Restore the original shape from the runtime Shape of the input rather
    // than from the static shape vector: the static shape may contain several
    // -1 (dynamic) entries, and ONNX Reshape accepts at most one -1.
    auto shape_node = helper_->MakeNode("Shape", {input_info[0].name});
    helper_->MakeNode("Reshape", {result_name, shape_node->output(0)},
                      {output_info[0].name});
  }
}

}  // namespace paddle2onnx
- 实现axis不为None时的转换:对每一组(shift, axis),沿该轴用同样的Slice + Concat方式滚动,最后一次Concat直接写入算子输出:
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"
namespace paddle2onnx {

namespace {
// Maps a roll offset into [0, dim_size) when the rolled extent is known.
//
// Paddle lets |shift| exceed the size of the rolled dimension (the roll wraps
// around), but ONNX Slice clamps out-of-range starts/ends: without this step a
// shift of 7 on a length-5 axis would leave the tensor unchanged instead of
// rolling it by 2. A non-positive dim_size means "unknown/dynamic"; the shift
// is then returned as-is and is only correct if |shift| <= the runtime size.
int64_t NormalizeShift(int64_t shift, int64_t dim_size) {
  if (dim_size <= 0) {
    return shift;
  }
  const int64_t remainder = shift % dim_size;
  return remainder < 0 ? remainder + dim_size : remainder;
}
}  // namespace

// Converts paddle.roll with explicit axes: for every (shift, axis) pair the
// tensor is split along that axis by two Slices and re-joined by a Concat with
// the tail first. The axis=None branch is filled in by the other tutorial step.
void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");
  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);
  std::vector<int64_t> axis;
  GetAttr("axis", &axis);
  auto result_name = input_info[0].name;
  if (axis.empty()) {
    // axis=None is handled in the flatten-based step of this tutorial.
  } else {
    const auto& input_shape = input_info[0].shape;
    const int64_t rank = static_cast<int64_t>(input_shape.size());
    // Paddle requires len(shifts) == len(axis); take the shorter length anyway
    // so a malformed op can never index past the end of `axis`.
    const size_t num_rolls =
        shifts.size() < axis.size() ? shifts.size() : axis.size();
    if (num_rolls == 0) {
      // Nothing to roll, but the op output must still be produced.
      helper_->MakeNode("Identity", {result_name}, {output_info[0].name});
      return;
    }
    for (size_t i = 0; i < num_rolls; ++i) {
      // ONNX only specifies negative axes for Slice from opset 10 and Concat
      // from opset 11, so map Paddle's negative axis onto [0, rank) here.
      int64_t roll_axis = axis[i];
      if (roll_axis < 0) {
        roll_axis += rank;
      }
      const int64_t dim_size =
          (roll_axis >= 0 && roll_axis < rank)
              ? input_shape[static_cast<size_t>(roll_axis)]
              : -1;
      const int64_t shift = NormalizeShift(shifts[i], dim_size);
      // Roll by `shift`: the trailing `shift` elements move to the front.
      auto tail = helper_->Slice(result_name, {roll_axis}, {-shift},
                                 {(std::numeric_limits<int64_t>::max)()});
      auto head = helper_->Slice(result_name, {roll_axis}, {0}, {-shift});
      std::shared_ptr<ONNX_NAMESPACE::NodeProto> concat_node = nullptr;
      if (i + 1 == num_rolls) {
        // The last roll writes straight into the op's output tensor.
        concat_node = helper_->MakeNode("Concat", {tail, head},
                                        {output_info[0].name});
      } else {
        concat_node = helper_->MakeNode("Concat", {tail, head});
      }
      AddAttribute(concat_node, "axis", roll_axis);
      result_name = concat_node->output(0);
    }
  }
}

}  // namespace paddle2onnx
- 将以上两部分合并后,roll.cpp的核心代码如下:
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"
namespace paddle2onnx {
REGISTER_MAPPER(roll, RollMapper)
void RollMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
std::vector<int64_t> shifts;
GetAttr("shifts", &shifts);
std::vector<int64_t> axis;
GetAttr("axis", &axis);
std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
auto result_name = input_info[0].name;
if (axis.empty())
{
int64_t axes = 0;
result_name = helper_->Flatten(result_name);
for(int i = 0;i < shifts.size();i++) {
auto shift = shifts[i];
auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
temp_node = helper_->MakeNode("Concat", {result_0, result_1});
AddAttribute(temp_node, "axis", axes);
result_name = temp_node->output(0);
}
helper_->Reshape(result_name, output_info[0].name, input_info[0].shape);
// helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name});
} else {
for(int i = 0;i < shifts.size();i++) {
auto shift = shifts[i];
int64_t axes = axis[i];
auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
if(i+1 == shifts.size()) {
temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name});
} else {
temp_node = helper_->MakeNode("Concat", {result_0, result_1});
}
AddAttribute(temp_node, "axis", axes);
result_name = temp_node->output(0);
}
}
}
} // namespace paddle2onnx