pytorch转onnx, onnx 12 中没有hardswish opt

在onnx opset 12下转以下模型时因不支持hardswish激活函数而报错

  • GhostNet
  • MobileNetv3Small
  • EfficientNetLite0
  • PP-LCNet
    解决方案是找到对应的nn.Hardswish层,将其替换为自己覆写的Hardswish实现:
class Hardswish(nn.Module):  # export-friendly version of nn.Hardswish()
    @staticmethod
    def forward(x):
        # return x * F.hardsigmoid(x)  # for torchscript and CoreML
        return x * F.hardtanh(x + 3, 0., 6.) / 6.  # for torchscript, CoreML and ONNX

PP-LCNet为例,找到哪些层是Hardswish层,替换方法为

# 替换函数, 参考https://zhuanlan.zhihu.com/p/356273702
def _set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)
for k, m in model.named_modules():
	if 'dw_sp.2' in k or 'dw_sp.6' in k:
		_set_module(model, k, Hardswish())

当然也可以根据m来判断是否为nn.Hardswish的实例,

for k, m in model.named_modules():
	if isinstance(m, nn.Hardswish):
		_set_module(model, k, Hardswish())

参考

YOLOv5-Multibackbone-Compression
Pytorch替换model对象任意层的方法

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值