Adding custom ONNX operations
Not every graph or operation defined in PyTorch is supported by ONNX, so you sometimes need to add operators yourself. If ONNX already supports an operation, it can be used directly; the supported operations are listed below:
add (nonzero alpha not supported)
sub (nonzero alpha not supported)
mul
div
cat
mm
addmm
neg
sqrt
tanh
sigmoid
mean
sum
prod
t
expand (only when used before a broadcasting ONNX operator; e.g., add)
transpose
view
split
squeeze
prelu (single weight shared among input channels not supported)
threshold (non-zero threshold/non-zero value not supported)
leaky_relu
glu
softmax (only dim=-1 supported)
avg_pool2d (ceil_mode not supported)
log_softmax
unfold (experimental support with ATen-Caffe2 integration)
elu
concat
abs
index_select
pow
clamp
max
min
eq
gt
lt
ge
le
exp
sin
cos
tan
asin
acos
atan
permute
Conv
BatchNorm
MaxPool1d (ceil_mode not supported)
MaxPool2d (ceil_mode not supported)
MaxPool3d (ceil_mode not supported)
Embedding (no optional arguments supported)
RNN
ConstantPadNd
Dropout
FeatureDropout (training mode not supported)
Index (constant integer and tuple indices supported)
Some models that can be implemented with the operators above:
AlexNet
DCGAN
DenseNet
Inception (warning: this model is highly sensitive to changes in operator implementation)
ResNet
SuperResolution
VGG
word_language_model
Adding an ONNX operation requires working with the PyTorch source code.

If the new operation can be expressed with an ATen operation (ATen is the C++11 library that PyTorch calls into under the hood, developed by the PyTorch team), you can find its declaration in torch/csrc/autograd/generated/VariableType.h and add it in torch/onnx/symbolic.py, following these steps:
1. Define a symbolic function in torch/onnx/symbolic.py, making sure its name matches the name of the ATen operation declared in VariableType.h.
2. The first parameter must always be the exported ONNX graph. For example, the softmax symbolic is declared as def softmax(g, input, dim): the first parameter must be g, and the remaining parameter names must match those in VariableType.h exactly.
3. Parameter order is not strictly enforced; by convention, tensor arguments such as input come first, followed by the non-tensor arguments.
4. If an argument is passed in as a tensor but ONNX expects a scalar, the conversion must be done explicitly: the helper _scalar converts a one-element tensor to a Python scalar, and _if_scalar_type_as converts a Python scalar into a PyTorch tensor.
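The steps above can be sketched as follows. The softmax symbolic mirrors the shape of the document's own example signature; the _FakeGraph stand-in is purely illustrative (the real g is supplied by the exporter), and the axis_i attribute name follows the attribute-naming convention of Graph.op:

```python
# Sketch of the steps above: a symbolic for the ATen `softmax` operation as it
# would appear in torch/onnx/symbolic.py. The function name matches the ATen
# declaration in VariableType.h, and the first argument is the ONNX graph.
def softmax(g, input, dim):
    # Only dim=-1 is supported by the exporter; `dim` is passed as the
    # integer `axis` attribute (hence the `_i` suffix in `axis_i`).
    return g.op("Softmax", input, axis_i=dim)

# Illustrative stand-in for the exporter's Graph object, so the calling
# convention can be demonstrated without running a real export.
class _FakeGraph(object):
    def op(self, opname, *inputs, **attrs):
        return (opname, inputs, attrs)

node = softmax(_FakeGraph(), "input_value", -1)
```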
If the new operation cannot be implemented with the ATen library, the symbolic function must be added to the relevant PyTorch Function class instead:
1. Create a static method on the relevant Function class, named symbolic.
2. As before, its first parameter must be the ONNX graph g.
3. The remaining parameters must be named exactly as in forward.
4. The output tuple must have the same size as the output of forward.
5. The symbolic function should be written in Python; the graph-building operations it uses are implemented via C++-Python bindings, with the following interface:
def operator/symbolic(g, *inputs):
    """
    Modifies Graph (e.g., using "op"), adding the ONNX operations representing
    this PyTorch function, and returning a Value or tuple of Values specifying the
    ONNX outputs whose values correspond to the original PyTorch return values
    of the autograd Function (or None if an output is not supported by ONNX).

    Arguments:
        g (Graph): graph to write the ONNX representation into
        inputs (Value...): list of values representing the variables which contain
            the inputs for this function
    """

class Value(object):
    """Represents an intermediate tensor value computed in ONNX."""
    def type(self):
        """Returns the Type of the value."""

class Type(object):
    def sizes(self):
        """Returns a tuple of ints representing the shape of a tensor this describes."""

class Graph(object):
    def op(self, opname, *inputs, **attrs):
        """
        Create an ONNX operator 'opname', taking 'args' as inputs
        and attributes 'kwargs' and add it as a node to the current graph,
        returning the value representing the single output of this
        operator (see the `outputs` keyword argument for multi-return
        nodes).

        The set of operators and the inputs/attributes they take
        is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

        Arguments:
            opname (string): The ONNX operator name, e.g., `Abs` or `Add`.
            args (Value...): The inputs to the operator; usually provided
                as arguments to the `symbolic` definition.
            kwargs: The attributes of the ONNX operator, with keys named
                according to the following convention: `alpha_f` indicates
                the `alpha` attribute with type `f`. The valid type specifiers are
                `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
                specified with type float accepts either a single float, or a
                list of floats (e.g., you would say `dims_i` for a `dims` attribute
                that takes a list of integers).
            outputs (int, optional): The number of outputs this operator returns;
                by default an operator is assumed to return a single output.
                If `outputs` is greater than one, this function returns a tuple
                of output `Value`, representing each output of the ONNX operator
                in positional order.
        """