Mxnet2Caffe mxnet模型转换caffe

Mxnet2Caffe

将mxnet静态图symbol转换为caffe的prototxt文本,支持大部分op;caffe不支持的op则需要自己添加自定义层后再转换,否则构建网络时会失败

  1. 将json转换为prototxt
  2. 利用caffe的python接口构建网络,将mxnet的参数param迁移到caffe网络中
  3. 构建caffe不支持的op
  4. 对结果进行比对

json_2_protxt

  1. json2prototxt.py prototxt_basic.py Read the mxnet_json file and convert it to prototxt
// json 格式,主要就是 op(操作节点和辅助节点null)、name、attrs(参数列表)、inputs(输入列表list)
    {
      "op": "Activation", 
      "name": "part_0_stage1_unit1_relu1", 
      "attrs": {"act_type": "relu"}, 
      "inputs": [[14, 0, 0]]
    }, 
    {
      "op": "null", 
      "name": "part_0_stage1_unit1_conv1_weight", 
      "attrs": {
        "kernel": "(3, 3)", 
        "no_bias": "True", 
        "num_filter": "64", 
        "pad": "(1, 1)", 
        "stride": "(1, 1)", 
        "workspace": "256"
      }, 
      "inputs": []
    }, 
    {
      "op": "Convolution", 
      "name": "part_0_stage1_unit1_conv1", 
      "attrs": {
        "kernel": "(3, 3)", 
        "no_bias": "True", 
        "num_filter": "64", 
        "pad": "(1, 1)", 
        "stride": "(1, 1)", 
        "workspace": "256"
      }, 
      "inputs": [[15, 0, 0], [16, 0, 0]]
    }, 

读取json文件,并存储相应信息


# Load the MXNet symbol graph (JSON) from disk.
with open(args.mx_json) as json_file:
  jdata = json.load(json_file)

# Walk every node of the graph and emit one prototxt layer per operator node.
with open(args.cf_prototxt, "w") as prototxt_file:
  nodes = jdata['nodes']
  for node_i in nodes:
    # Only operator nodes are converted. Auxiliary "null" nodes (weights,
    # biases, ...) are skipped, except for the network input named 'data'.
    if str(node_i['op']) == 'null' and str(node_i['name']) != 'data':
      continue

    # The node dict is extended in place with the extra fields consumed by
    # write_node and the per-op writers: top, bottom, params (and share).
    info = node_i
    info['top'] = info['name']
    info['bottom'] = []
    info['params'] = []

    # Classify each input of the current node. input_idx_i[0] is the index of
    # the producing node inside jdata['nodes'].
    for input_idx_i in node_i['inputs']:
      input_i = nodes[input_idx_i[0]]
      input_name = str(input_i['name'])
      is_null_op = str(input_i['op']) == 'null'

      # Operator outputs and the 'data' input become bottoms of this layer.
      if not is_null_op or input_name == 'data':
        info['bottom'].append(input_name)

      # "null" inputs are parameters of this layer.
      if is_null_op:
        info['params'].append(input_name)
        # A parameter whose name is not prefixed by this node's name is
        # treated as a weight shared with another layer.
        if not input_name.startswith(str(node_i['name'])):
          logging.info('           use shared weight -> %s'% input_name)
          info['share'] = True
    write_node(prototxt_file, info)
    

写prototxt文件

# 转换 Convolution 节点操作
def Convolution(txt_file, info):
  if info['attrs']['no_bias'] == 'True':
    bias_term = 'false'
  else:
    bias_term = 'true'  
  txt_file.write('layer {\n')
  txt_file.write('	bottom: "%s"\n'       % info['bottom'][0])
  txt_file.write('	top: "%s"\n'          % info['top'])
  txt_file.write('	name: "%s"\n'         % info['top'])
  txt_file.write('	type: "Convolution"\n')
  txt_file.write('	convolution_param {\n')
  txt_file.write('		num_output: %s\n'   % info['attrs']['num_filter'])
  txt_file.write('		kernel_size: %s\n'  % info['attrs']['kernel'].split('(')[1].split(',')[0]) # TODO
  if 'pad' not in info['attrs']:
    logging.info('miss Conv_pad, make pad default: 0 ')
    txt_file.write('		pad: %s\n' % 0)  # TODO
  else:
    txt_file.write('		pad: %s\n'          % info['attrs']['pad'].split('(')[1].split(',')[0]) # TODO
#  txt_file.write('		group: %s\n'        % info['attrs']['num_group'])
  txt_file.write('		stride: %s\n'       % info['attrs']['stride'].split('(')[1].split(',')[0])
  txt_file.write('		bias_term: %s\n'    % bias_term)
  txt_file.write('	}\n')
  if 'share' in info.keys() and info['share']:  
    txt_file.write('	param {\n')
    txt_file.write('	  name: "%s"\n'     % info['params'][0])
    txt_file.write('	}\n')
  txt_file.write('}\n')
  txt_file.write('\n')

# -------根据op操作,完善相应的转换函数-----------
# 目前包含Conv Pool DepthConv BN Act ele_add Concat FC Reshape etc. 
def write_node(txt_file, info):
    """Dispatch one graph node to the prototxt writer matching its MXNet op.

    Args:
        txt_file: writable text file object receiving the prototxt text.
        info: node dict; ``info['op']`` selects the writer and ``info['name']``
            is used to filter out label nodes.

    Nodes whose name contains ``'label'`` are skipped silently. An op with no
    matching writer is reported with a warning and nothing is written.
    """
    if 'label' in info['name']:
        return

    op = info['op']
    if op == 'null' and info['name'] == 'data':
        data(txt_file, info)
    elif op == 'Convolution':
        Convolution(txt_file, info)
    elif op == 'ChannelwiseConvolution':
        ChannelwiseConvolution(txt_file, info)
    elif op == 'BatchNorm':
        BatchNorm(txt_file, info)
    elif op == 'Activation':
        Activation(txt_file, info)
    elif op in ('elemwise_add', '_Plus'):
        # Both ops are written with the same element-wise sum writer.
        ElementWiseSum(txt_file, info)
    elif op == 'Concat':
        Concat(txt_file, info)
    elif op == 'Pooling':
        # Pooling nodes are routed to the global-pooling writer.
        Pooling_global(txt_file, info)
    elif op == 'Flatten':
        Flatten(txt_file, info)
    elif op == 'FullyConnected':
        FullyConnected(txt_file, info)
    elif op == 'SoftmaxOutput':
        SoftmaxOutput(txt_file, info)
    elif op == 'Cast':
        Cast(txt_file, info)
    elif op == 'SliceChannel':
        SliceChannel(txt_file, info)
    elif op == 'L2Normalization':
        L2Normalization(txt_file, info)
    elif op == 'Reshape':
        Reshape(txt_file, info)
    elif op == 'broadcast_mul':
        broadcast_mul(txt_file, info)
    else:
        # logging.warn is a deprecated alias that newer Python versions drop.
        logging.warning("Unknown mxnet op: %s", op)

利用caffe的python接口,构建网络,并迁移mxnet的网络参数

1. mxnet2caffe.py Read the mxnet_model params_dict and convert it to .caffemodel

转换的时候如果存在caffe不支持的op,需要自己添加自定义层,否则在构建网络时会报错(error)。本工程添加了broadcast_mul层;caffe添加自定义层的介绍比较多,这里就跳过了

根据mxnet的API (load) 加载param文件的所有参数字典

# sys is needed unconditionally (sys.exit below), not only on the fallback path.
import os
import sys

try:
    import caffe
except ImportError:
    # Fall back to a local pycaffe build when caffe is not on the import path.
    sys.path.append("/home/***/codes/mx2caffe/caffe/python/")
    import caffe
# Load the complete parameter dictionaries from the MXNet checkpoint.
_, arg_params, aux_params = mx.model.load_checkpoint(args.mx_model, args.mx_epoch)
# list(...) keeps this working when keys() returns a view rather than a list.
all_keys = list(arg_params) + list(aux_params)
# Build the Caffe network from the prototxt produced by the previous step.
net = caffe.Net(args.cf_prototxt, caffe.TRAIN)

for i_key,key_i in enumerate(all_keys):
  try:
    if key_i == 'data':
      # Nothing to copy for the input; skip it (key_caffe would be undefined).
      continue
    # MXNet parameter names carry suffixes (_weight, _bias, ...) that Caffe
    # layer names do not have. The blob index used for each suffix follows the
    # order assumed here: [0] weight / [1] bias, and similarly for other ops.
    elif '_weight' in key_i:
      key_caffe = key_i.replace('_weight','')
      net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
    elif '_bias' in key_i:
      key_caffe = key_i.replace('_bias','')
      net.params[key_caffe][1].data.flat = arg_params[key_i].asnumpy().flat
    elif '_gamma' in key_i:
      # gamma/beta go to blobs [0]/[1] of the layer named "<prefix>_scale".
      key_caffe = key_i.replace('_gamma','_scale')
      net.params[key_caffe][0].data.flat = arg_params[key_i].asnumpy().flat
    elif '_beta' in key_i:
      key_caffe = key_i.replace('_beta','_scale')
      net.params[key_caffe][1].data.flat = arg_params[key_i].asnumpy().flat
    elif '_moving_mean' in key_i:
      # Running statistics come from aux_params; blob [2] is set to 1.
      key_caffe = key_i.replace('_moving_mean','')
      net.params[key_caffe][0].data.flat = aux_params[key_i].asnumpy().flat
      net.params[key_caffe][2].data[...] = 1
    elif '_moving_var' in key_i:
      key_caffe = key_i.replace('_moving_var','')
      net.params[key_caffe][1].data.flat = aux_params[key_i].asnumpy().flat
      net.params[key_caffe][2].data[...] = 1
    else:
      sys.exit("Warning!  Unknown mxnet:{}".format(key_i))

    print("% 3d | %s -> %s, initialized."
           %(i_key, key_i.ljust(40), key_caffe.ljust(30)))

  except KeyError:
    # The Caffe net has no layer with the derived name; report and continue.
    print("\nWarning!  key error mxnet:{}".format(key_i))

# ------------------------------------------
# Finish
net.save(args.cf_model)
print("\n- Finished.\n")

对转换结果进行比对确认

  1. mxnet_test.py Debug mxnet output and you can compare the result with the converted caffemodel

使用mxnet debug, 打印需要对比的参数,并且输出指定层的结果

import mxnet as mx

def load_checkpoint_single(model, param_path):
    """Load a ``.params`` file into ``model`` and return its parameters.

    Every key of the saved dict is split once on ``':'`` into a type prefix
    and a parameter name. Entries prefixed ``arg`` and ``aux`` are collected
    separately; entries with any other prefix are ignored.

    Args:
        model: object exposing ``set_params`` and ``get_params``.
        param_path: path passed to ``mx.nd.load``.

    Returns:
        The ``(arg_params, aux_params)`` pair read back from ``model`` after
        the loaded values were applied with ``allow_missing=False``.
    """
    buckets = {'arg': {}, 'aux': {}}
    for saved_key, tensor in mx.nd.load(param_path).items():
        prefix, param_name = saved_key.split(':', 1)
        target = buckets.get(prefix)
        if target is not None:
            target[param_name] = tensor
    model.set_params(buckets['arg'], buckets['aux'], allow_missing=False)
    loaded_args, loaded_aux = model.get_params()
    return loaded_args, loaded_aux

full_param_path = 'se_resnet34/base-0000.params'
fmodel = mx.sym.load('se_resnet34/base-symbol.json')
# Expose every internal output of the MXNet network.
all_layers = fmodel.get_internals()

# To dump a different layer, change this key to "<layer name>_output".
fmodel = all_layers['flat_output']
fullmodel = mx.mod.Module(symbol=fmodel,data_names=['data'],label_names=[])

img = get_image_gray('before_forward.jpg')
# NOTE(review): shape presumably means batch 1, 1 channel, 108x108 - confirm
# against what get_image_gray returns.
fullmodel.bind(data_shapes=[('data', (1, 1, 108, 108))], label_shapes=None, for_training=False, force_rebind=False)

arg_params, aux_params = load_checkpoint_single(fullmodel, full_param_path)
fullmodel.set_params(arg_params,aux_params)

# The context manager closes the output file even if forward or save raises.
with open('se_resnet34.txt','w') as file1:
    tic=time.time()

    fullmodel.forward(Batch([mx.nd.array(img)]))

    prob = fullmodel.get_outputs()[0].asnumpy()
    prob = prob.astype(np.float64)
    prob = prob.reshape(-1,1)
    # Save one value per line with 12 decimals.
    np.savetxt(file1,prob,fmt='%.12f')

然后利用Caffe 加载刚才转换的网络,打印输出,对比结果精度,如果出现问题,则需要逐层排查,本工程在SENet网络上测试正常

工程项目地址

https://github.com/junqiangwu/Mxnet2Caffe-Tensor-RT-SEnet

TODO:
  • add caffe_plugin_layer
  • Tensor RT load caffe_model
  • Tensor RT supported Se_Resnet
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值