pytorch转onnx相关问题

这些问题是在转谱归一化spectral_norm中遇到的。

首先遇到的就是torch.mv算子和torch.dot算子不支持的问题。

目前pytorch已经官方实现了谱归一化:spectral_norm,其中包含了torch.mv、 torch.dot算子,转onnx会出现错误

解决办法:将torch.mv和torch.dot用torch.matmul代替,不过可能需要自己改变一下tensor的维度。(通过unsqueeze之类的)

我再解决了上述两个算子后,能够跑torch.onnx.export函数,但是转换推断的时候会出现:

RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 2 and 1

貌似是维度出现了问题,但是我找了几个小时都没有找到问题所在。

后来解决的办法是,在转onnx之前,除去spectral_norm。

具体参考了:https://github.com/pytorch/pytorch/issues/27723

官方已经实现了如何移除spectral_norm的函数:

def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError("spectral_norm of '{}' not found in {}".format(
            name, module))

    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break
    return module

具体的步骤是:

1.按照训练的时候构建模型model(此时还是含有spectral_norm),并且装载pretrained model,这个pretrained model中含有spectral_norm的相关参数:weight_orig、weight_u以及weight_v。

2.之后利用以下函数,这个函数的输入是构建的model,完成的是递归模型的结构,当遇见spectral_norm时,会调用上面的remove_spectral_norm移除spectral_norm。

def remove_all_spectral_norm(item):
    if isinstance(item, nn.Module):
        try:
            remove_spectral_norm(item)
        except Exception:
            pass
        
        for child in item.children():  
            remove_all_spectral_norm(child)

    if isinstance(item, nn.ModuleList):
        for module in item:
            remove_all_spectral_norm(module)

    if isinstance(item, nn.Sequential):
        modules = item.children()
        for module in modules:
            remove_all_spectral_norm(module)

3.最后跑torch.onnx.export

这里稍微解释一下,普通的卷积操作后面增加spectral_norm后,训练的参数会从卷积的weight会变为weight_orig、weight_u、weight_v这三类,也就是在保存模型的时候保存的都是这些参数。

通过上述的移除操作,会从weight_orig、weight_u、weight_v恢复出weight。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值