Part 1: loading a model whose network structure is identical to the pretrained model
# !/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import time
import math
import mxnet as mx
import cv2
import numpy as np
from collections import namedtuple
# Fine-tune a pretrained checkpoint with the same network structure.
data_shape_G = 96
Batch = namedtuple('Batch', ['data'])
# os.path.expanduser is required: neither Python's open() nor MXNet expands "~",
# so the original prefix=r"~/meh_cla" would fail to find the checkpoint files.
sym, arg_params, aux_params = mx.model.load_checkpoint(
    prefix=os.path.expanduser("~/meh_cla"), epoch=2)
train, val = get_iterators(batch_size=batch_size, data_shape=(3, 96, 96))
train = Multi_mnist_iterator(train)
val = Multi_mnist_iterator(val)
model = mx.mod.Module(  # rebuild a Module around the pretrained symbol
    symbol=sym,
    context=device,
    data_names=['data'],
    # Three softmax heads; these names must match the pretrained network's outputs.
    label_names=['softmax1_label', 'softmax2_label', 'softmax3_label']
)
model.bind(data_shapes=train.provide_data, label_shapes=train.provide_label)
# allow_missing=True lets any layer absent from the checkpoint keep its random init.
model.set_params(arg_params, aux_params, allow_missing=True)
model.fit(train, val,
          optimizer_params={'learning_rate': lr, 'momentum': 0.9},
          num_epoch=num_epochs,
          eval_metric=MAE_zz(name="mae"),
          batch_end_callback=mx.callback.Speedometer(batch_size, 2),
          epoch_end_callback=checkpoint)
Part 2: loading a model whose network structure differs from the pretrained model
This version works:
# Re-attach pretrained weights to a freshly built (different) symbol, then re-save.
sym, arg_params, aux_params = mx.model.load_checkpoint(r"model", 0)
sym_new = get_symbol(512)
mod_new = mx.mod.Module(symbol=sym_new, context=[mx.cpu()])
data_shape_w = 112
mod_new.bind(for_training=False,
             data_shapes=[('data', (1, 3, data_shape_w, data_shape_w))])
# dict() makes a shallow copy in one step; the original element-by-element
# key loops did exactly the same thing with more code.
arg_params_new = dict(arg_params)
aux_params_new = dict(aux_params)
# allow_missing=True: layers that exist only in the new symbol keep random init.
mod_new.set_params(arg_params_new, aux_params_new, allow_missing=True)
mod_new.save_checkpoint("asdf", 0)
This version raises an error:
# NOTE(review): failing variant, kept intentionally for reference — see comments.
sym, arg_params, aux_params = mx.model.load_checkpoint(r"model", 0)
sym_new=get_symbol(512)
# Unlike the working variant, this one passes label_names=None.
mod_new=mx.mod.Module(symbol=sym_new,context=mx.cpu(0),label_names=None)
data_shape_w=112
# mod_new.binded=True
# NOTE(review): likely the failure point — mod_new.label_shapes is read here
# while the module is still unbound; Module properties assert that bind() has
# already run, so this raises before binding happens. With label_names=None the
# label_shapes argument should simply be omitted — confirm against the traceback.
mod_new.bind(for_training=False,data_shapes=[('data',(1,3,data_shape_w,data_shape_w))],label_shapes=mod_new.label_shapes)
arg_params_new=dict()
aux_params_new=dict()
for key in arg_params.keys():
    arg_params_new[key]=arg_params[key]
for key in aux_params.keys():
    aux_params_new[key]=aux_params[key]
mod_new.set_params(arg_params_new,aux_params_new,allow_missing=True)
# mod_new.save_checkpoint("asdf",0)
The following is the mfnv2 (MobileFaceNet-v2-style) network definition; this version works:
import mxnet as mx
bn_mom = 0.9
#bn_mom = 0.9997
def Act(data, act_type, name):
    """Activation block; act_type is accepted but deliberately ignored —
    the network is hard-wired to PReLU."""
    return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)
def Conv(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
    """Conv -> BatchNorm -> activation building block."""
    base = '%s%s' % (name, suffix)
    out = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel,
                             num_group=num_group, stride=stride, pad=pad,
                             no_bias=True, name=base + '_conv2d')
    out = mx.sym.BatchNorm(data=out, name=base + '_batchnorm',
                           fix_gamma=False, momentum=bn_mom)
    return Act(data=out, act_type='relu', name=base + '_relu')
def Linear(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
    """Conv -> BatchNorm with no activation (linear bottleneck)."""
    base = '%s%s' % (name, suffix)
    out = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel,
                             num_group=num_group, stride=stride, pad=pad,
                             no_bias=True, name=base + '_conv2d')
    return mx.sym.BatchNorm(data=out, name=base + '_batchnorm',
                            fix_gamma=False, momentum=bn_mom)
def ConvOnly(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
    """Bare convolution with no BatchNorm and no activation."""
    return mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel,
                              num_group=num_group, stride=stride, pad=pad,
                              no_bias=True, name='%s%s_conv2d' % (name, suffix))
def DResidual(data, num_out=1, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=1, name=None, suffix=''):
    """Depthwise-separable bottleneck: 1x1 expansion -> depthwise conv -> 1x1 linear projection."""
    base = '%s%s' % (name, suffix)
    expanded = Conv(data=data, num_filter=num_group, kernel=(1, 1), pad=(0, 0),
                    stride=(1, 1), name=base + '_conv_sep')
    depthwise = Conv(data=expanded, num_filter=num_group, num_group=num_group,
                     kernel=kernel, pad=pad, stride=stride, name=base + '_conv_dw')
    return Linear(data=depthwise, num_filter=num_out, kernel=(1, 1), pad=(0, 0),
                  stride=(1, 1), name=base + '_conv_proj')
def Residual(data, num_block=1, num_out=1, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=1, name=None, suffix=''):
    """Stack of num_block DResidual units, each with an identity shortcut added."""
    out = data
    for idx in range(num_block):
        branch = DResidual(data=out, num_out=num_out, kernel=kernel,
                           stride=stride, pad=pad, num_group=num_group,
                           name='%s%s_block' % (name, suffix), suffix='%d' % idx)
        out = branch + out  # elementwise residual addition
    return out
def get_symbol(num_classes, **kwargs):
    """Build the mfnv2 embedding symbol.

    num_classes is the embedding dimension of the final fc1 output.
    Keyword args: bn_mom (BatchNorm momentum, default 0.9) and wd_mult
    (weight-decay multiplier for the pre_fc1 weight, default 1.0).
    Returns the 'fc1' symbol (BatchNorm over a fully-connected projection).
    """
    global bn_mom
    bn_mom = kwargs.get('bn_mom', 0.9)
    wd_mult = kwargs.get('wd_mult', 1.)
    data = mx.symbol.Variable(name="data")
    # In-graph input normalization: (pixel - 127.5) * (1/128).
    data = data-127.5
    data = data*0.0078125
    # Stem: stride-2 conv then a depthwise conv.
    conv_1 = Conv(data, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_1")
    conv_2_dw = Conv(conv_1, num_group=64, num_filter=64, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name="conv_2_dw")
    # Alternating downsampling bottlenecks (stride 2) and residual stages (stride 1).
    conv_23 = DResidual(conv_2_dw, num_out=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=128, name="dconv_23")
    conv_3 = Residual(conv_23, num_block=4, num_out=64, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=128, name="res_3")
    conv_34 = DResidual(conv_3, num_out=128, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=256, name="dconv_34")
    conv_4 = Residual(conv_34, num_block=6, num_out=128, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=256, name="res_4")
    conv_45 = DResidual(conv_4, num_out=128, kernel=(3, 3), stride=(2, 2), pad=(1, 1), num_group=512, name="dconv_45")
    conv_5 = Residual(conv_45, num_block=2, num_out=128, kernel=(3, 3), stride=(1, 1), pad=(1, 1), num_group=256, name="res_5")
    # Head: 1x1 expansion, then a 7x7 depthwise "global" linear conv
    # (assumes a 7x7 spatial map here, i.e. 112x112 input — TODO confirm).
    conv_6_sep = Conv(conv_5, num_filter=512, kernel=(1, 1), pad=(0, 0), stride=(1, 1), name="conv_6sep")
    conv_6_dw = Linear(conv_6_sep, num_filter=512, num_group=512, kernel=(7,7), pad=(0, 0), stride=(1, 1), name="conv_6dw7_7")
    #conv_6_dw = mx.symbol.Dropout(data=conv_6_dw, p=0.4)
    # Explicit weight variable so wd_mult can be set on the embedding layer.
    _weight = mx.symbol.Variable("pre_fc1_weight", shape=(num_classes, 512), lr_mult=1.0, wd_mult=wd_mult)
    conv_6_f = mx.sym.FullyConnected(data=conv_6_dw, weight=_weight, num_hidden=num_classes, name='pre_fc1')
    # Final BN (fix_gamma=True) produces the 'fc1' embedding output.
    fc1 = mx.sym.BatchNorm(data=conv_6_f, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
    return fc1
if __name__ == '__main__':
    # Load pretrained weights and re-attach them to a freshly built symbol,
    # then save the combined model as a new checkpoint.
    sym, arg_params, aux_params = mx.model.load_checkpoint(r"mv2_55\model", 0)
    sym_new = get_symbol(512)
    mod_new = mx.mod.Module(symbol=sym_new, context=[mx.cpu()])
    data_shape_w = 112
    mod_new.bind(for_training=False,
                 data_shapes=[('data', (1, 3, data_shape_w, data_shape_w))])
    # dict() takes a shallow copy in one step; the original per-key copy loops
    # were redundant. allow_missing=True keeps random init for any new layers.
    mod_new.set_params(dict(arg_params), dict(aux_params), allow_missing=True)
    mod_new.save_checkpoint("asdf", 0)
# ----------------------------
# data_shape_w=112
#
# ctx = [mx.cpu()]
# from easydict import EasyDict as edict
# net = edict()
#
# net.ctx = ctx
# net.sym, net.arg_params, net.aux_params =mx.model.load_checkpoint(r"E:\jinji\AidLearning-FrameWork-master\src\facencnn\models\mv2_55\model", 0)
# #net.arg_params, net.aux_params = ch_dev(net.arg_params, net.aux_params, net.ctx)
# all_layers = net.sym.get_internals()
# net.sym = all_layers['fc1_output']
# net.model = mx.mod.Module(symbol=net.sym,
# context=net.ctx,
# label_names=None)
# net.model.bind(data_shapes=[('data', (1, 3, data_shape_w,data_shape_w))])
# net.model.set_params(net.arg_params, net.aux_params,allow_missing=False)
# net.model.save_checkpoint("asdf",0)