pytorch模型转mxnet

介绍

gluon把mxnet再进行封装,封装的风格非常接近pytorch

使用gluon的好处是非常容易把pytorch模型向mxnet转化

唯一的问题是gluon封装还不成熟,封装好的layer不多,很多常用的layer 如concat,upsampling等layer都没有

这里关注如何把pytorch 模型快速转换成 mxnet基于symbol 和 exector设计的网络

pytorch转mxnet module

关键点:

  • mxnet 设计网络时symbol 名称要和pytorch初始化中各网络层名称对应 
  • torch.load()读入pytorch模型checkpoint 字典,取当中的'state_dict'元素,也是一个字典
  • pytorch state_dict 字典中key是网络层参数的名称,val是参数ndarray
  • pytorch 的参数名称的组织形式和mxnet一样,但是连接符号不同,pytorch是'.',而mxnet是'_'比如:

pytorch '0.conv1.0.weight'

mxnet  '0_conv1_0_weight'

  • pytorch 的参数array 和mxnet 的参数array 完全一样,只要名称对上,直接赋值即可初始化mxnet模型

需要做的有以下几点:

  • 设计和pytorch网络对应的mxnet网络
  • 加载pytorch checkpoint
  • 调整pytorch checkpoint state_dict 的key名称和mxnet命名格式一致

FlowNet2S PytorchToMxnet

pytorch flownet2S 的checkpoint 可以在github上搜到

import mxnet as mx
from symbol_util import *
import pickle

def get_loss(data, label, loss_scale, name, get_input=False, is_sparse = False, type='stereo'):

    if type == 'stereo':
        data = mx.sym.Activation(data=data, act_type='relu',name=name+'relu')
    # loss
    if  is_sparse:
        loss =mx.symbol.Custom(data=data, label=label, name=name, loss_scale= loss_scale, is_l1=True,
            op_type='SparseRegressionLoss')
    else:
        loss = mx.sym.MAERegressionOutput(data=data, label=label, name=name, grad_scale=loss_scale)
    return (loss,data) if get_input else loss


def flownet_s(loss_scale, is_sparse=False, name=''):
    img1 = mx.symbol.Variable('img1')
    img2 = mx.symbol.Variable('img2')
    data = mx.symbol.concat(img1,img2,dim=1)
    labels = {'loss{}'.format(i): mx.sym.Variable('loss{}_label'.format(i)) for i in range(0, 7)}
    # print('labels: ',labels)
    prediction = {}# a dict for loss collection
    loss = []#a list

    #normalize
    data = (data-125)/255

    # extract featrue
    conv1 = mx.sym.Convolution(data, pad=(3, 3), kernel=(7, 7), stride=(2, 2), num_filter=64, name=name + 'conv1_0')
    conv1 = mx.sym.LeakyReLU(data=conv1, act_type='leaky', slope=0.1)

    
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值