mxnet迁移学习

一、加载模型与pretrain模型network相同

# !/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import time
import math
import mxnet as mx
import cv2
import numpy as np
from collections import namedtuple

# Case 1: the fine-tune network is identical to the pretrained network,
# so the checkpoint can be loaded and its parameters applied directly.
data_shape_G = 96
Batch = namedtuple('Batch', ['data'])
# FIX: load_checkpoint opens the files directly and does NOT expand "~",
# so the home-relative prefix must be expanded explicitly.
sym, arg_params, aux_params = mx.model.load_checkpoint(
    prefix=os.path.expanduser(r"~/meh_cla"), epoch=2)

# get_iterators / Multi_mnist_iterator / batch_size etc. are defined
# elsewhere (not shown in this snippet).
train, val = get_iterators(batch_size=batch_size, data_shape=(3, 96, 96))
train = Multi_mnist_iterator(train)
val = Multi_mnist_iterator(val)

# Build a module around the pretrained symbol; the three softmax label
# names must match the heads present in the loaded network structure.
model = mx.mod.Module(
    symbol=sym,
    context=device,
    data_names=['data'],
    label_names=['softmax1_label', 'softmax2_label', 'softmax3_label']
)
model.bind(data_shapes=train.provide_data, label_shapes=train.provide_label)
# allow_missing=True: any parameter of the symbol that is absent from the
# checkpoint keeps its (later-initialized) value instead of raising.
model.set_params(arg_params, aux_params, allow_missing=True)
model.fit(train, val,
          optimizer_params={'learning_rate': lr, 'momentum': 0.9},
          num_epoch=num_epochs,
          eval_metric=MAE_zz(name="mae"),
          batch_end_callback=mx.callback.Speedometer(batch_size, 2),
          epoch_end_callback=checkpoint
          )

 

二、加载模型与pretrain模型network不同
(原文此处有一张示意图,转载/提取时图片已丢失)

下面这种写法可以正常运行:

  sym, arg_params, aux_params = mx.model.load_checkpoint(r"model", 0)
    sym_new=get_symbol(512)
    mod_new=mx.mod.Module(symbol=sym_new,context=[mx.cpu()])#,label_names=None)
    data_shape_w=112
    # mod_new.binded=True
    mod_new.bind(for_training=False,data_shapes=[('data',(1,3,data_shape_w,data_shape_w))])
    arg_params_new=dict()
    aux_params_new=dict()
    for key in arg_params.keys():
        arg_params_new[key]=arg_params[key]
    for key in aux_params.keys():
        aux_params_new[key]=aux_params[key]

    mod_new.set_params(arg_params_new,aux_params_new,allow_missing=True)
    mod_new.save_checkpoint("asdf",0)

而下面这种写法会报错:

    # BROKEN example, kept for reference — this variant raises at bind() time.
    sym, arg_params, aux_params = mx.model.load_checkpoint(r"model", 0)
    sym_new=get_symbol(512)
    # label_names=None declares a label-free module ...
    mod_new=mx.mod.Module(symbol=sym_new,context=mx.cpu(0),label_names=None)
    data_shape_w=112
    # mod_new.binded=True
    # BUG: mod_new.label_shapes is read BEFORE bind() has run — presumably
    # Module.label_shapes asserts the module is already bound, so this line
    # fails; with label_names=None no label_shapes should be passed at all.
    # TODO confirm exact exception against the MXNet Module API.
    mod_new.bind(for_training=False,data_shapes=[('data',(1,3,data_shape_w,data_shape_w))],label_shapes=mod_new.label_shapes)
    arg_params_new=dict()
    aux_params_new=dict()
    # Manual per-key copies of the pretrained parameter dicts.
    for key in arg_params.keys():
        arg_params_new[key]=arg_params[key]
    for key in aux_params.keys():
        aux_params_new[key]=aux_params[key]

    mod_new.set_params(arg_params_new,aux_params_new,allow_missing=True)
    # mod_new.save_checkpoint("asdf",0)

下面是 mfnv2(MobileFaceNet v2)的完整示例代码,经验证可以正常运行:

import mxnet as mx

# Default BatchNorm momentum for all Conv/Linear blocks below;
# get_symbol() may override it via kwargs['bn_mom'].
bn_mom = 0.9
#bn_mom = 0.9997

def Act(data, act_type, name):
    """Activation block.

    The act_type argument is deliberately ignored: the network is
    hard-wired to PReLU (via LeakyReLU with act_type='prelu').
    """
    return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)

def Conv(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
    """Bias-free Convolution -> BatchNorm -> activation block."""
    out = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel,
                             num_group=num_group, stride=stride, pad=pad,
                             no_bias=True, name='%s%s_conv2d' % (name, suffix))
    out = mx.sym.BatchNorm(data=out, name='%s%s_batchnorm' % (name, suffix),
                           fix_gamma=False, momentum=bn_mom)
    return Act(data=out, act_type='relu', name='%s%s_relu' % (name, suffix))

def Linear(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
    """Bias-free Convolution -> BatchNorm block (no activation)."""
    out = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel,
                             num_group=num_group, stride=stride, pad=pad,
                             no_bias=True, name='%s%s_conv2d' % (name, suffix))
    return mx.sym.BatchNorm(data=out, name='%s%s_batchnorm' % (name, suffix),
                            fix_gamma=False, momentum=bn_mom)

def ConvOnly(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
    """Bare bias-free convolution, with no normalization or activation."""
    return mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel,
                              num_group=num_group, stride=stride, pad=pad,
                              no_bias=True, name='%s%s_conv2d' % (name, suffix))


def DResidual(data, num_out=1, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=1, name=None, suffix=''):
    """Depthwise-separable bottleneck: 1x1 expand -> depthwise -> 1x1 project.

    num_group doubles as the expansion width (filters of the first 1x1 conv
    and channels of the depthwise conv); num_out is the projected output width.
    """
    expanded = Conv(data=data, num_filter=num_group, kernel=(1, 1), pad=(0, 0),
                    stride=(1, 1), name='%s%s_conv_sep' % (name, suffix))
    depthwise = Conv(data=expanded, num_filter=num_group, num_group=num_group,
                     kernel=kernel, pad=pad, stride=stride,
                     name='%s%s_conv_dw' % (name, suffix))
    # Linear (no-activation) projection back down to num_out channels.
    return Linear(data=depthwise, num_filter=num_out, kernel=(1, 1), pad=(0, 0),
                  stride=(1, 1), name='%s%s_conv_proj' % (name, suffix))

def Residual(data, num_block=1, num_out=1, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=1, name=None, suffix=''):
    """Stack num_block bottlenecks, each wrapped in an identity shortcut.

    Assumes stride/num_out keep the shape constant so the elementwise
    addition with the shortcut is valid.
    """
    out = data
    for idx in range(num_block):
        branch = DResidual(data=out, num_out=num_out, kernel=kernel,
                           stride=stride, pad=pad, num_group=num_group,
                           name='%s%s_block' % (name, suffix),
                           suffix='%d' % idx)
        out = branch + out
    return out


def get_symbol(num_classes, **kwargs):
    """Build the MobileFaceNet-style symbol and return the 'fc1' output layer.

    num_classes is the embedding dimension of the final FullyConnected layer.
    Recognized kwargs: bn_mom (BatchNorm momentum, default 0.9) and
    wd_mult (weight-decay multiplier for the pre_fc1 weight, default 1.0).
    """
    global bn_mom
    bn_mom = kwargs.get('bn_mom', 0.9)
    wd_mult = kwargs.get('wd_mult', 1.)

    net = mx.symbol.Variable(name="data")
    # Normalize raw pixels: (x - 127.5) * 1/128.
    net = (net - 127.5) * 0.0078125

    # Stem.
    net = Conv(net, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_1")
    net = Conv(net, num_group=64, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_2_dw")

    # Alternating strided bottlenecks (downsample) and residual stacks.
    net = DResidual(net, num_out=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=128, name="dconv_23")
    net = Residual(net, num_block=4, num_out=64, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=128, name="res_3")
    net = DResidual(net, num_out=128, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=256, name="dconv_34")
    net = Residual(net, num_block=6, num_out=128, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=256, name="res_4")
    net = DResidual(net, num_out=128, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=512, name="dconv_45")
    net = Residual(net, num_block=2, num_out=128, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=256, name="res_5")

    # Head: 1x1 expand, 7x7 depthwise "global" linear layer, then FC + BN.
    net = Conv(net, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_6sep")
    net = Linear(net, num_filter=512, num_group=512, kernel=(7, 7), pad=(0, 0), stride=(1, 1), name="conv_6dw7_7")

    _weight = mx.symbol.Variable("pre_fc1_weight", shape=(num_classes, 512),
                                 lr_mult=1.0, wd_mult=wd_mult)
    net = mx.sym.FullyConnected(data=net, weight=_weight,
                                num_hidden=num_classes, name='pre_fc1')
    return mx.sym.BatchNorm(data=net, fix_gamma=True, eps=2e-5,
                            momentum=bn_mom, name='fc1')

if __name__ == '__main__':
    # Transplant matching weights from the checkpoint "mv2_55\model" (epoch 0)
    # into a freshly built 512-d symbol, then write it back out as
    # "asdf-symbol.json" / "asdf-0000.params".
    sym, arg_params, aux_params = mx.model.load_checkpoint(r"mv2_55\model", 0)
    sym_new = get_symbol(512)
    mod_new = mx.mod.Module(symbol=sym_new, context=[mx.cpu()])
    data_shape_w = 112
    mod_new.bind(for_training=False,
                 data_shapes=[('data', (1, 3, data_shape_w, data_shape_w))])
    # Shallow dict copies replace the original per-key copy loops (equivalent
    # but far less verbose).
    arg_params_new = dict(arg_params)
    aux_params_new = dict(aux_params)
    # allow_missing=True: parameters of sym_new absent from the checkpoint
    # keep their current values instead of raising.
    mod_new.set_params(arg_params_new, aux_params_new, allow_missing=True)
    mod_new.save_checkpoint("asdf", 0)


 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法加油站

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值