pytorch padding_Pytorch转Msnhnet模型思路分享

来源 | GiantPandaCV

作者 | BBuf

【导读】这篇文章要为大家介绍一下MsnhNet的模型转换思路,大多数搞CV的小伙伴都知道:珍爱生命,远离模型转换。但是啊,当你想部署训练出来的模型时(如Pytorch训练的),模型转换又是必不可少的步骤,这个时候就真的很愁。为什么调用xxx脚本转换出来的模型报错了?为什么转换出来的模型推理结果不正确呢?所以,我认为在做模型转换时拥有一个清晰的分析思路是不可少的。这篇文章就为大家分享了一下最近开源的前向推理框架MsnhNet是如何将原始的Pytorch模型较为优雅的转换过来,希望我们介绍的思路可以对有模型转换需求的同学带来一定启发。

网络结构的转换

网络结构转换比较复杂,其原因在于涉及到不同的op以及相关的基础操作.

  • 「思路一」: 利用print的结果进行构建
    • 「优点」: 简单易用
    • 「缺点」: 大部分网络,print并不能完全展现出其结构.简单网络可用.
  • 代码实现:
import torchimport torch.nn as nnclass Model(nn.Module):    def __init__(self):        super(Model, self).__init__()        self.conv1 = nn.Conv2d(1, 6, 5)        self.bn1   = nn.BatchNorm2d(6,eps=1e-5,momentum=0.1)        self.relu1 = nn.ReLU()        self.pool1 = nn.MaxPool2d(2)        self.conv2 = nn.Conv2d(6, 16, 5)        self.bn2   = nn.BatchNorm2d(16,eps=1e-5,momentum=0.1)        self.relu2 = nn.ReLU()        self.pool2 = nn.MaxPool2d(2)    def forward(self, x):        y = self.conv1(x)        y = self.bn1(y)        y = self.relu1(y)         y = self.pool1(y)         y = self.conv2(y)        y = self.bn2(y)         y = self.relu2(y)         y = self.pool2(y)        return ynn = Model()print(nn)
  • 结果: 很显然对于用纯nn.Module搭建的网络是可行的
Model(  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  (relu1): ReLU()  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  (relu2): ReLU()  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
  • 如果在forward内添加相关操作,则此方案将无效.
  • 代码实现:
import torchimport torch.nn as nnclass Model(nn.Module):    def __init__(self):        super(Model, self).__init__()        self.conv1 = nn.Conv2d(1, 6, 5)        self.bn1   = nn.BatchNorm2d(6,eps=1e-5,momentum=0.1)        self.relu1 = nn.ReLU()        self.pool1 = nn.MaxPool2d(2)        self.conv2 = nn.Conv2d(6, 16, 5)        self.bn2   = nn.BatchNorm2d(16,eps=1e-5,momentum=0.1)        self.relu2 = nn.ReLU()        self.pool2 = nn.MaxPool2d(2)    def forward(self, x):        y = self.conv1(x)        y = self.bn1(y)        y = self.relu1(y)         y = self.pool1(y)         y = self.conv2(y)        y = self.bn2(y)         y = self.relu2(y)         y = self.pool2(y)        y = torch.flatten(y)        return ynn = Model()print(nn)
  • 结果: 很显然forward内的flatten操作并没有被导出.
Model(  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  (relu1): ReLU()  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  (relu2): ReLU()  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
  • 「思路二」: 通过类似Windows Hook技术.(思路来源pytorch_to_caffe) 在pytorch的Op在执行之前,对此Op进行截取,以获取相关信息,从而实现网络构建.
    • 优点: 几乎可以完成所有pytorch的op导出.
    • 缺点: 实现复杂,容易误操作,可能影响pytorch本身结果错误.
  • 代码实现: 通过构建Hook类, 重写op, 并替换原op操作,获取op的参数. 层的上下关系,通过tensor的_cdata作为唯一识别的ID.
import torchimport torch.nn as nnfrom torchsummary import summaryimport torch.nn.functional as FlogMsg = Trueccc = []# Hook截取类class Hook(object):    hookInited = False    def __init__(self,raw,replace,**kwargs):        self.obj=replace # 被截取之后的op        self.raw=raw # 原op    def __call__(self,*args,**kwargs):        if not Hook.hookInited: #在Hook类未初始化之前,该信号原路返回            return self.raw(*args,**kwargs)        else:                   #否则,则按截取之后,实现的函数执行            out=self.obj(self.raw,*args,**kwargs)            return outdef log(*args):    if logMsg:        print(*args)# 替换原cov2d函数的实现def _conv2d(raw,inData, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):    # 对于上下层网络关系,可以使用tensor的_cdata,该参数类似唯一ID    # 输入tensor的唯一ID    log( "conv2d-i" , inData._cdata)     x=raw(inData,weight,bias,stride,padding,dilation,groups)    ccc.append(x)                    # 此处将输出保存,防止被inplace操作,导致所有tensor的_cdata丧失唯一性    # 此处就可以根据conv2d参数就行网络构建    # msnhnet.buildConv2d(...)     # 输出tensor的唯一ID    log( "conv2d-o" , x._cdata)    return x# 被替换OP                  原OP     自定义OPF.conv2d        =   Hook(F.conv2d,_conv2d)
  • 完整Demo:
import torchimport torch.nn as nnfrom torchsummary import summaryimport torch.nn.functional as FlogMsg = Trueccc = []# Hook截取类class Hook(object):    hookInited = False    def __init__(self,raw,replace,**kwargs):        self.obj=replace # 被截取之后的op        self.raw=raw # 原op    def __call__(self,*args,**kwargs):        if not Hook.hookInited: #在Hook类未初始化之前,该信号原路返回            return self.raw(*args,**kwargs)        else:                   #否则,则按截取之后,实现的函数执行            out=self.obj(self.raw,*args,**kwargs)            return outdef log(*args):    if logMsg:        print(*args)# 替换原cov2d函数的实现def _conv2d(raw,inData, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):    # 对于上下层网络关系,可以使用tensor的_cdata,该参数类似唯一ID    # 输入tensor的唯一ID    log( "conv2d-i" , inData._cdata)     x=raw(inData,weight,bias,stride,padding,dilation,groups)    ccc.append(x)                    # 此处将输出保存,防止被inplace操作,导致所有tensor的_cdata丧失唯一性    # 此处就可以根据conv2d参数就行网络构建    # msnhnet.buildConv2d(...)     # 输出tensor的唯一ID    log( "conv2d-o" , x._cdata)    return xdef _relu(raw, inData, inplace=False):    log( "relu-i" , inData._cdata)    x = raw(inData,False)    ccc.append(x)    log( "relu-o" , x._cdata)    return xdef _batch_norm(raw,inData, running_mean, running_var, weight=None, bias=None,training=False, momentum=0.1, eps=1e-5):    log( "bn-i" , inData._cdata)    x = raw(inData, running_mean, running_var, weight, bias, training, momentum, eps)    ccc.append(x)    log( "bn-o" , x._cdata)    return xdef _flatten(raw,*args):    log( "flatten-i" , args[0]._cdata)    x=raw(*args)    ccc.append(x)    log( "flatten-o" , x._cdata)    return x# 被替换OP                   原OP       自定义OPF.conv2d        =   Hook(F.conv2d,_conv2d)F.batch_norm    =   Hook(F.batch_norm,_batch_norm)F.relu          =   Hook(F.relu,_relu)torch.flatten   =   Hook(torch.flatten,_flatten)class Model(nn.Module):    def __init__(self):        super(Model, self).__init__()        self.conv1 = nn.Conv2d(1, 6, 5)        self.bn1   = nn.BatchNorm2d(6,eps=1e-5,momentum=0.1)        self.relu1 = nn.ReLU()    def forward(self, x):        y = self.conv1(x)        y = self.bn1(y)        y = self.relu1(y)         y = torch.flatten(y)        return yinput_var = torch.autograd.Variable(torch.rand(1, 1, 28, 28))nn = Model()nn.eval()Hook.hookInited = Trueres = nn(input_var)
  • 结果: flatten操作也完成了导出, 且每个op的input的ID都能在前面找到对应op的output的ID.即可知晓上下层之间的关系,由此,即可构建msnhnet.
conv2d-i 2748363239504conv2d-o 2748363238224bn-i 2748363238224bn-o 2748363242832relu-i 2748363242832relu-o 2748363235152flatten-i 2748363235152flatten-o 2748363242064

参数的转换

  • 「思路一」: 利用pytorch的state_dict字典,可直接进行导出. 由于msnhnet和pytorch的内存排布是一致的,都为NCHW模式,且对于BN层的参数顺序也相同,都为scale, bias, mean和var.只需将参数进行逐个提取,然后按二进制存储即可。
    • 优点: 可以在不知道网络运行结构的时候对参数进行导出, 简单易用.
    • 缺点: 当网络使用参数的顺序和保存的顺序不一致时,会出现错误.
import torchvision.models as modelsimport torchfrom struct import packmd = models.resnet18(pretrained = True)md.to("cpu")md.eval()val = []dd = 0for name in md.state_dict():        if "num_batches_tracked" not in name:                c = md.state_dict()[name].data.flatten().numpy().tolist()                dd = dd + len(c)                print(name, ":", len(c))                val.extend(c)with open("alexnet.msnhbin","wb") as f:    for i in val :        f.write(pack('f',i))

注意上面出现了一行if "num_batches_tracked" not in name:,这一行是Pytorch的一个坑点,在pytorch 0.4.1及后面的版本里,BatchNorm层新增了num_batches_tracked参数,用来统计训练时的forward过的batch数目,源码如下(pytorch0.4.1):

  if self.training and self.track_running_stats:        self.num_batches_tracked += 1        if self.momentum is None:  # use cumulative moving average            exponential_average_factor = 1.0 / self.num_batches_tracked.item()        else:  # use exponential moving average            exponential_average_factor = self.momentum

在调用预训练参数模型时,官方给定的预训练模型是在pytorch0.4之前。因此,调用预训练参数时,需要过滤掉“num_batches_tracked”。

  • 「思路二」: 利用之前的Hook,在算子运行时,对参数进行提取,暂存,最后统一保存.
    • 优点: 网络参数和网络结构同时导出,保证参数与网络运行结构一致性.
    • 缺点: 需要获取网络的运行顺序才能完成转换.
  • 代码实现:
...m_weights = []def _conv2d(raw,inData, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):        x=raw(inData,weight,bias,stride,padding,dilation,groups)    if Hook.hookInited :        log( "conv2d-i" , inData._cdata)        ccc.append(x)        log( "conv2d-o" , x._cdata)        useBias = True        if bias is None:            useBias = False                m_weights.extend(weight.numpy().flatten().tolist()) #暂存        if useBias :            m_weights.extend(bias.numpy().flatten().tolist()) #暂存        msnhnet.checkInput(inData,sys._getframe().f_code.co_name)        msnhnet.buildConv2d(str(x._cdata), x.size()[1], weight.size()[2], weight.size()[3],                             padding[0], padding[1], stride[0], stride[1], dilation[0], dilation[1], groups, useBias)    return x...def trans(net, inputVar, msnhnet_path, msnhbin_path):    Hook.hookInited = True    msnhnet.buildConfig(str(id(inputVar)), inputVar.size())    net.forward(inputVar)    with open(msnhnet_path,"w") as f1:        f1.write(msnhnet.net)    with open(msnhbin_path,"wb") as f: # 参数保存        for i in m_weights :            f.write(pack('f',i))     Hook.hookInited = False

详细转换过程代码编写

这里先截取一下构建MsnhNet的部分代码,完整代码见https://github.com/msnh2012/Msnhnet/blob/master/tools/pytorch2Msnhnet/PytorchToMsnhnet.py,如下:

from collections import OrderedDictimport sysclass Msnhnet:    def __init__(self):        self.inAddr = ""        self.net = ""        self.index = 0        self.names = []        self.indexes = []    def setNameAndIdx(self, name, ids):        self.names.append(name)        self.indexes.append(ids)    def getIndexFromName(self,name):        ids = self.indexes[self.names.index(name)]        return ids    def getLastVal(self):        return self.indexes[-1]    def getLastKey(self):        return self.names[-1]    def checkInput(self, inAddr,fun):        if self.index == 0:            return        if str(inAddr._cdata) != self.getLastKey():            try:                ID = self.getIndexFromName(str(inAddr._cdata))                self.buildRoute(str(inAddr._cdata),str(ID),False)            except:                 raise NotImplementedError("last op is not supported " + fun + str(inAddr._cdata))                def buildConfig(self, inAddr, shape):        self.inAddr = inAddr        self.net = self.net + "config:"        self.net = self.net + "  batch: " + str(int(shape[0])) + ""        self.net = self.net + "  channels: " + str(int(shape[1])) + ""        self.net = self.net + "  height: " + str(int(shape[2])) + ""        self.net = self.net + "  width: " + str(int(shape[3])) + ""     def buildConv2d(self, name, filters, kSizeX, kSizeY, paddingX, paddingY, strideX, strideY, dilationX, dilationY, groups, useBias):        self.setNameAndIdx(name,self.index)        self.net = self.net + "#" + str(self.index) +  ""        self.index = self.index + 1        self.net = self.net + "conv:"        self.net = self.net + "  filters: " + str(int(filters)) + ""        self.net = self.net + "  kSizeX: " + str(int(kSizeX)) + ""        self.net = self.net + "  kSizeY: " + str(int(kSizeY)) + ""        self.net = self.net + "  paddingX: " + str(int(paddingX)) + ""        self.net = self.net + "  paddingY: " + str(int(paddingY)) + ""        self.net = self.net + "  strideX: " + str(int(strideX)) + ""        self.net = self.net + "  strideY: " + str(int(strideY)) + ""        self.net = self.net + "  dilationX: " + str(int(dilationX)) + ""        self.net = self.net + "  dilationY: " + str(int(dilationY)) + ""        self.net = self.net + "  groups: " + str(int(groups)) + ""        self.net = self.net + "  useBias: " + str(int(useBias)) + ""

然后Pytorch2MsnhNet就在前向传播的过程中按照我们介绍的Hook技术完成构建Pytorch模型对应的MsnhNet模型结构。

至此,我们就获得了MsnhNet的模型参数文件和权重文件,可以利用MsnhNet加载模型进行推理了。

已经支持的OP以及转换实例

Pytorch2MsnhNet已经支持转换的OP如下:

-  conv2d-  max_pool2d-  avg_pool2d-  adaptive_avg_pool2d-  linear-  flatten-  dropout-  batch_norm-  interpolate(nearest, bilinear)-  cat   -  elu-  selu-  relu-  relu6-  leaky_relu-  tanh-  softmax-  sigmoid-  softplus-  abs    -  acos   -  asin   -  atan   -  cos    -  cosh   -  sin    -  sinh   -  tan    -  exp    -  log    -  log10  -  mean-  permute-  view-  contiguous-  sqrt-  pow-  sum-  pad-  +|-|x|/|+=|-=|x=|/=|
  • ResNet18的转换示例:
import torchimport torch.nn as nnfrom torchvision.models import resnet18from PytorchToMsnhnet import *resnet18=resnet18(pretrained=True)resnet18.eval()input=torch.ones([1,3,224,224])trans(resnet18, input,"resnet18.msnhnet","resnet18.msnhbin")
  • DeepLabV3的转换示例:
import torchimport torch.nn as nnfrom torchvision.models.segmentation import deeplabv3_resnet101from PytorchToMsnhnet import *deeplabv3=deeplabv3_resnet101(pretrained=False)ccc = torch.load("C:/Users/msnh/.cache/torch/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth")del ccc["aux_classifier.0.weight"]del ccc["aux_classifier.1.weight"]del ccc["aux_classifier.1.bias"]del ccc["aux_classifier.1.running_mean"]del ccc["aux_classifier.1.running_var"]del ccc["aux_classifier.1.num_batches_tracked"]del ccc["aux_classifier.4.weight"]del ccc["aux_classifier.4.bias"]deeplabv3.load_state_dict(ccc)deeplabv3.requires_grad_(False)deeplabv3.eval()input=torch.ones([1,3,224,224])# trans msnhnet and msnhbin filetrans(deeplabv3, input,"deeplabv3.msnhnet","deeplabv3.msnhbin")
d096a461bc804a789ec684930d88f224
import torchimport torch.nn as nnfrom torchvision.models.segmentation import deeplabv3_resnet101from PytorchToMsnhnet import *deeplabv3=deeplabv3_resnet101(pretrained=False)ccc = torch.load("C:/Users/msnh/.cache/torch/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth")del ccc["aux_classifier.0.weight"]del ccc["aux_classifier.1.weight"]del ccc["aux_classifier.1.bias"]del ccc["aux_classifier.1.running_mean"]del ccc["aux_classifier.1.running_var"]del ccc["aux_classifier.1.num_batches_tracked"]del ccc["aux_classifier.4.weight"]del ccc["aux_classifier.4.bias"]deeplabv3.load_state_dict(ccc)deeplabv3.requires_grad_(False)deeplabv3.eval()input=torch.ones([1,3,224,224])# trans msnhnet and msnhbin filetrans(deeplabv3, input,"deeplabv3.msnhnet","deeplabv3.msnhbin")

MsnhNet介绍

MsnhNet是一款基于纯c++的轻量级推理框架,此框架受到darknet启发,由穆士凝魂主导,并由本公众号作者团队业余协助开发。

项目地址:https://github.com/msnh2012/Msnhnet ,欢迎一键三连。

本框架目前已经支持了X86、Cuda、Arm端的推理(支持的OP有限,正努力开发中),并且可以直接将Pytorch模型(后面也会尝试接入更多框架)转为本框架的模型进行部署,欢迎对前向推理框架感兴趣的同学试用或者加入我们一起维护这个轮子。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值