pyTorch onnx 学习(二)

2 篇文章 0 订阅
2 篇文章 0 订阅

添加自定义的onnx operations

在pyTorch中定义的网络图以及其运算,在onnx中不一定支持,因此,需要自定义的添加operators。如果onnx支持则可以直接使用,一下是支持的网络以及运算:

add (nonzero alpha not supported)
sub (nonzero alpha not supported)
mul
div
cat
mm
addmm
neg
sqrt
tanh
sigmoid
mean
sum
prod
t
expand (only when used before a broadcasting ONNX operator; e.g., add)
transpose
view
split
squeeze
prelu (single weight shared among input channels not supported)
threshold (non-zero threshold/non-zero value not supported)
leaky_relu
glu
softmax (only dim=-1 supported)
avg_pool2d (ceil_mode not supported)
log_softmax
unfold (experimental support with ATen-Caffe2 integration)
elu
concat
abs
index_select
pow
clamp
max
min
eq
gt
lt
ge
le
exp
sin
cos
tan
asin
acos
atan
permute
Conv
BatchNorm
MaxPool1d (ceil_mode not supported)
MaxPool2d (ceil_mode not supported)
MaxPool3d (ceil_mode not supported)
Embedding (no optional arguments supported)
RNN
ConstantPadNd
Dropout
FeatureDropout (training mode not supported)
Index (constant integer and tuple indices supported)

基于以上operators可以实现的一些模型有:

AlexNet
DCGAN
DenseNet
Inception (warning: this model is highly sensitive to changes in operator implementation)
ResNet
SuperResolution
VGG
word_language_model

增加onnx的operation需要接触到pyTorch的源码。

  • 如果增加的operation可以用ATen operation(ATen是pyTorch底层调用的C++ 11库,由pytorch团队开发的)实现,则可以在torch/csrc/autograd/generated/VariableType.h中找到他的声明,在torch/onnx/symbolic.py中添加它,按以下步骤:

    1. torch/onnx/symbolic.py中定义声明函数,确保函数名与在头文件中VariableType.h的ATen operation的函数名一样.
    2. 函数中的第一个参数必须为ONNX模型图,如softmax operation的函数名def softmax(g, input, dim):第一个参数必须是g,其他参数名必须同VariableType.h完全一致.
      3.参数的顺序没有强制性要求,一般input参数为张量类型,然后是其他参数为非张量参数.
    3. 如果输入参数是张量,但是ONNX要求标量,我们必须明确地进行转换。 辅助函数_scalar可以将标量张量转换为python标量,_if_scalar_type_as可以将Python标量转换为PyTorch张量
  • 如果增加的operation不能用ATen库实现,则需要在相关的pyTorch Function 类中添加声明函数,操作如下:

    1. 在相关的Function类中创建一个函数,如命名为symbolic.
    2. 同样的第一个参数必须是ONNX图g.
    3. 其他参数命名必须与forward中的名字一致.
    4. 输出的tuple大小必须与forward的输出大小一致.
    5. 声明函数应该使用python定义,方法的具体实现使用C++-Python绑定实现,具体接口如下:
def operator/symbolic(g, *inputs):
  """
  Modifies Graph (e.g., using "op"), adding the ONNX operations representing
  this PyTorch function, and returning a Value or tuple of Values specifying the
  ONNX outputs whose values correspond to the original PyTorch return values
  of the autograd Function (or None if an output is not supported by ONNX).

  Arguments:
    g (Graph): graph to write the ONNX representation into
    inputs (Value...): list of values representing the variables which contain
        the inputs for this function
  """

class Value(object):
  """Represents an intermediate tensor value computed in ONNX."""
  def type(self):
    """Returns the Type of the value."""

class Type(object):
  def sizes(self):
    """Returns a tuple of ints representing the shape of a tensor this describes."""

class Graph(object):
  def op(self, opname, *inputs, **attrs):
    """
    Create an ONNX operator 'opname', taking 'args' as inputs
    and attributes 'kwargs' and add it as a node to the current graph,
    returning the value representing the single output of this
    operator (see the `outputs` keyword argument for multi-return
    nodes).

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    Arguments:
        opname (string): The ONNX operator name, e.g., `Abs` or `Add`.
        args (Value...): The inputs to the operator; usually provided
            as arguments to the `symbolic` definition.
        kwargs: The attributes of the ONNX operator, with keys named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`.  The valid type specifiers are
            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
            specified with type float accepts either a single float, or a
            list of floats (e.g., you would say `dims_i` for a `dims` attribute
            that takes a list of integers).
        outputs (int, optional):  The number of outputs this operator returns;
            by default an operator is assumed to return a single output.
            If `outputs` is greater than one, this functions returns a tuple
            of output `Value`, representing each output of the ONNX operator
            in positional.
    """
  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值