Using Pretrained Models in Python: Initializing a PyTorch Model from MXnet Pretrained Weights

1. MXnet symbolic graphs:

A graph built from MXnet symbols is a static computation graph: both the graph structure and the memory management are fixed before execution. Taking ResNet50_v2 as an example, the symbolic graph of a Bottleneck block is built as follows:

```python
import mxnet as mx

def residual_unit(data, num_filter, stride, dim_match, name,
                  bn_mom=0.9, workspace=256):
    # Enclosing signature added for completeness; it follows the
    # residual_unit function of the official MXnet resnet symbol script.
    bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
    act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
    conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
                               no_bias=True, workspace=workspace, name=name + '_conv1')
    bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
    act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
    conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1),
                               no_bias=True, workspace=workspace, name=name + '_conv2')
    bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
    act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
    conv3 = mx.sym.Convolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
                               workspace=workspace, name=name + '_conv3')
    if dim_match:
        shortcut = data
    else:
        # v2 (pre-activation): the projection shortcut also consumes act1
        shortcut = mx.sym.Convolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                                      workspace=workspace, name=name+'_sc')
    return conv3 + shortcut
```

2. Loading the symbolic graph and the model parameters:

An MXnet pretrained model consists of a JSON graph definition and a params file:

-- resnet-50-0000.params
-- resnet-50-symbol.json

Loading these two files yields the graph structure, the model weights, and the auxiliary state (e.g. BatchNorm running statistics):

```python
import os
import mxnet as mx

prefix, index, num_layer = 'resnet-50', args.epoch, 50
prefix = os.path.join(ROOT_PATH, "./mx_model/models/{}".format(prefix))
symbol, param_args, param_auxs = mx.model.load_checkpoint(prefix, index)
```
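Before mapping anything, it helps to know what kinds of entries the checkpoint holds. MXnet parameter names are flat strings whose suffixes identify the layer type: `_gamma`/`_beta` for BatchNorm, `_weight`/`_bias` for conv and FC layers, and `_moving_mean`/`_moving_var` in the aux dict. A dependency-free sketch that groups names by their base layer (the sample names are hypothetical, in the style of the official resnet-50 checkpoint):

```python
def group_params(names):
    """Group flat MXnet parameter names by base layer name.

    Suffix conventions: _gamma/_beta (BatchNorm), _weight/_bias (Conv/FC),
    _moving_mean/_moving_var (BatchNorm aux state).
    """
    suffixes = ('_moving_mean', '_moving_var', '_gamma', '_beta', '_weight', '_bias')
    layers = {}
    for name in names:
        for suf in suffixes:
            if name.endswith(suf):
                # strip the suffix, record which parameter kind this layer has
                layers.setdefault(name[:-len(suf)], []).append(suf.lstrip('_'))
                break
    return layers

# hypothetical names mimicking the resnet-50 symbol script's naming
names = ['stage1_unit1_bn1_gamma', 'stage1_unit1_bn1_beta',
         'stage1_unit1_conv1_weight', 'fc1_weight', 'fc1_bias']
print(group_params(names))
```

Listing the keys of `param_args` and `param_auxs` this way makes it easy to see which parse helper each layer needs.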

3. PyTorch dynamic graphs:

PyTorch is a define-by-run framework: the computation graph is built and memory is managed dynamically, which makes it well suited to research-oriented development. Thanks to its imperative programming style, the value of any tensor in the graph, and of its gradient, can be inspected at any time. The Bottleneck block of ResNet50_v2 looks like this:

```python
import torch.nn as nn

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super(Bottleneck, self).__init__()
        # eps matches the MXnet BatchNorm setting (eps=2e-5)
        self.bn1 = nn.BatchNorm2d(inplanes, eps=2e-5)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, eps=2e-5)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes, eps=2e-5)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        if downsample:
            self.conv_sc = nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=stride, bias=False)
        self.stride = stride

    def forward(self, input):
        out = self.bn1(input)
        out1 = self.relu(out)
        residual = input
        out = self.conv1(out1)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        if self.downsample:
            # v2 projection shortcut consumes the pre-activated input,
            # matching the MXnet symbolic graph above
            residual = self.conv_sc(out1)
        out += residual
        return out
```
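Dimension bookkeeping: the 3x3 conv in the main path carries the stride, and the 1x1 projection shortcut uses the same stride, so both paths agree in spatial size. A quick check using the standard convolution output-size formula out = floor((in + 2*pad - k)/stride) + 1:

```python
def conv_out(size, kernel, stride, pad):
    # standard convolution output-size formula
    return (size + 2 * pad - kernel) // stride + 1

# main path on a 56x56 map with stride 2: 1x1 -> 3x3(s=2, p=1) -> 1x1
main_path = conv_out(conv_out(conv_out(56, 1, 1, 0), 3, 2, 1), 1, 1, 0)
# the 1x1 stride-2 shortcut conv reaches the same size
shortcut = conv_out(56, 1, 2, 0)
print(main_path, shortcut)  # 28 28
```

With stride=1 and dim_match, the identity shortcut needs no projection, which is exactly the `downsample` flag above.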

4. Parsing the MXnet parameters and initializing the PyTorch model:

First, convert the MXnet parameters into a dictionary of NumPy arrays. BatchNorm, Conv2d, and fully connected (Linear) layers are parsed as follows:

```python
import numpy as np

def bn_parse(args, auxs, name, args_dict, fix_gamma=False, eps=2e-5):
    """ name[0]: PyTorch layer name;
        name[1]: MXnet layer name."""
    args_dict[name[0]] = {}
    if not fix_gamma:
        args_dict[name[0]]['running_mean'] = auxs[name[1]+'_moving_mean'].asnumpy()
        args_dict[name[0]]['running_var'] = auxs[name[1]+'_moving_var'].asnumpy()
        args_dict[name[0]]['gamma'] = args[name[1]+'_gamma'].asnumpy()
        args_dict[name[0]]['beta'] = args[name[1]+'_beta'].asnumpy()
    else:
        # gamma is fixed to 1, so fold beta into the running mean:
        # (x - mean)/sqrt(var+eps) + beta == (x - (mean - beta*sqrt(var+eps)))/sqrt(var+eps)
        _mv = auxs[name[1]+'_moving_var'].asnumpy()
        _mm = auxs[name[1]+'_moving_mean'].asnumpy() - np.multiply(args[name[1]+'_beta'].asnumpy(), np.sqrt(_mv+eps))
        args_dict[name[0]]['running_mean'] = _mm
        args_dict[name[0]]['running_var'] = _mv
    return args_dict
```
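The fix_gamma branch relies on an algebraic identity: with gamma fixed at 1, (x - mu)/sqrt(var + eps) + beta equals (x - mu')/sqrt(var + eps) with mu' = mu - beta*sqrt(var + eps). A small numeric check of that identity in plain Python (scalar stand-ins for the NumPy arrays):

```python
import math

def bn_fixed_gamma(x, mean, var, beta, eps=2e-5):
    # BN output with gamma fixed to 1 (fix_gamma=True in MXnet)
    return (x - mean) / math.sqrt(var + eps) + beta

def bn_folded(x, mean, var, beta, eps=2e-5):
    # same output with beta folded into the running mean, as bn_parse does
    folded_mean = mean - beta * math.sqrt(var + eps)
    return (x - folded_mean) / math.sqrt(var + eps)

x, mean, var, beta = 1.7, 0.3, 0.9, -0.5
print(abs(bn_fixed_gamma(x, mean, var, beta) - bn_folded(x, mean, var, beta)) < 1e-12)  # True
```

This is why the fix_gamma case stores only running_mean and running_var: the PyTorch BN layer for it can be affine-free, with beta absorbed into the statistics.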

```python
def conv_parse(args, auxs, name, args_dict):
    """ name[0]: PyTorch layer name;
        name[1]: MXnet layer name."""
    args_dict[name[0]] = {}
    w = args[name[1]+'_weight'].asnumpy()
    args_dict[name[0]]['weight'] = w  # layout (N, M, k1, k2), same as PyTorch's (out, in, kH, kW)
    return args_dict
```

```python
def fc_parse(args, auxs, name, args_dict):
    """ name[0]: PyTorch layer name;
        name[1]: MXnet layer name."""
    args_dict[name[0]] = {}
    args_dict[name[0]]['weight'] = args[name[1]+'_weight'].asnumpy()
    args_dict[name[0]]['bias'] = args[name[1]+'_bias'].asnumpy()
    return args_dict
```
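Each parse helper takes a `name` pair of (PyTorch name, MXnet name); the arg_parse driver used below (not shown in the post) presumably enumerates such pairs over all stages and units. A dependency-free sketch of that enumeration, where the PyTorch side is an assumed torchvision-style `layer{s}.{u}` naming and the MXnet side follows the official resnet symbol script:

```python
def make_bn_pairs(units, prefix='module.'):
    """Enumerate (pytorch_name, mxnet_name) pairs for the three BNs of every
    bottleneck unit.  units[s-1] = number of units in stage s.
    PyTorch naming (layer{s}.{u}.bn{b}) is an assumption here; MXnet naming
    (stage{s}_unit{u}_bn{b}) follows the official resnet symbol script.
    """
    pairs = []
    for s, n_units in enumerate(units, start=1):
        for u in range(1, n_units + 1):
            for b in (1, 2, 3):
                pt = '%slayer%d.%d.bn%d' % (prefix, s, u - 1, b)
                mx = 'stage%d_unit%d_bn%d' % (s, u, b)
                pairs.append((pt, mx))
    return pairs

pairs = make_bn_pairs([3, 4, 6, 3])  # ResNet-50 stage layout
print(len(pairs))                    # 48 = (3+4+6+3) * 3
print(pairs[0])                      # ('module.layer1.0.bn1', 'stage1_unit1_bn1')
```

The same looping structure, extended to the convs and the stem/head layers, is what a full arg_parse would iterate while calling the parse helpers above.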

Then walk through every module of the PyTorch model and assign the parsed parameters layer by layer, which completes the initialization of the PyTorch model from the MXnet pretrained weights:

```python
import torch
import torch.nn as nn

# model initialization for PyTorch from MXnet params
class resnet(object):
    def __init__(self, name, num_layer, args, auxs, prefix='module.'):
        self.name = name
        num_stages = 4
        if num_layer == 50:
            units = [3, 4, 6, 3]
        elif num_layer == 101:
            units = [3, 4, 23, 3]
        self.num_layer = str(num_layer)
        # arg_parse builds {pytorch_name: {param: ndarray}} via the parse helpers above
        self.param_dict = arg_parse(args, auxs, num_stages, units, prefix=prefix)

    def bn_init(self, n, m):
        if m.weight is not None:
            # affine BN: copy gamma/beta as weight/bias
            m.weight.data.copy_(torch.FloatTensor(self.param_dict[n]['gamma']))
            m.bias.data.copy_(torch.FloatTensor(self.param_dict[n]['beta']))
        m.running_mean.copy_(torch.FloatTensor(self.param_dict[n]['running_mean']))
        m.running_var.copy_(torch.FloatTensor(self.param_dict[n]['running_var']))

    def conv_init(self, n, m):
        m.weight.data.copy_(torch.FloatTensor(self.param_dict[n]['weight']))

    def fc_init(self, n, m):
        m.weight.data.copy_(torch.FloatTensor(self.param_dict[n]['weight']))
        m.bias.data.copy_(torch.FloatTensor(self.param_dict[n]['bias']))

    def init_model(self, model):
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                self.bn_init(n, m)
            elif isinstance(m, nn.Conv2d):
                self.conv_init(n, m)
            elif isinstance(m, nn.Linear):
                self.fc_init(n, m)
        return model
```
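The core of init_model is type-based dispatch over named_modules(). Stripped of the framework, the pattern looks like this (MockBN and MockConv are stand-ins for nn.BatchNorm2d and nn.Conv2d, and the handlers are just labels here):

```python
class MockBN:
    pass

class MockConv:
    pass

def dispatch_init(named_modules, handlers):
    """Apply the first handler whose type matches each module,
    mirroring the isinstance chain in resnet.init_model."""
    applied = []
    for name, module in named_modules:
        for cls, handler_name in handlers:
            if isinstance(module, cls):
                applied.append((name, handler_name))
                break
    return applied

modules = [('bn1', MockBN()), ('conv1', MockConv()), ('fc', object())]
handlers = [(MockBN, 'bn_init'), (MockConv, 'conv_init')]
print(dispatch_init(modules, handlers))
# [('bn1', 'bn_init'), ('conv1', 'conv_init')]  -- unmatched modules are skipped
```

One consequence of this design: modules without a matching handler (here the 'fc' entry, since no Linear handler is registered) are silently left at their default initialization, so the handler list must cover every layer type that should receive pretrained weights.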

5. Using MXnet's data loader:

Once the output of mx.io.ImageRecordIter is converted to PyTorch tensors, MXnet's data pipeline can drive the training, validation, and testing of a PyTorch model. The iterator is designed as follows:

```python
# __iter__ of a loader wrapper holding self.data (an ImageRecordIter) and self.cuda
def __iter__(self):
    for batch in self.data:
        nd_data = batch.data[0].asnumpy()
        nd_label = batch.label[0].asnumpy()
        input_data = torch.FloatTensor(nd_data)
        input_label = torch.LongTensor(nd_label)
        if self.cuda:
            yield input_data.cuda(non_blocking=True), input_label.cuda(non_blocking=True)
        else:
            yield input_data, input_label
```
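The same adapter pattern, independent of MXnet and PyTorch: wrap a batch source in a class whose __iter__ converts each batch before yielding it. A minimal stdlib-only sketch (the conversion functions stand in for the FloatTensor/LongTensor calls):

```python
class BatchAdapter:
    """Wraps an iterable of (data, label) batches and converts each batch
    with user-supplied functions before yielding, mirroring the loader above."""
    def __init__(self, source, convert_data, convert_label):
        self.source = source
        self.convert_data = convert_data
        self.convert_label = convert_label

    def __iter__(self):
        for data, label in self.source:
            # conversion happens lazily, one batch at a time
            yield self.convert_data(data), self.convert_label(label)

batches = [([0.5, 1.5], [0, 1]), ([2.5], [1])]
adapter = BatchAdapter(batches, convert_data=tuple, convert_label=tuple)
print(list(adapter))  # [((0.5, 1.5), (0, 1)), ((2.5,), (1,))]
```

Because conversion happens inside the generator, each batch is transformed only when the training loop asks for it, which keeps memory usage at one batch.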
