mxnet.gluon Block.save_parameters()

This post describes how to use the save_parameters() member function of MXNet Gluon's Block class to save the parameters of a ResNet-18 network. A worked example shows the structure of the saved file, including weight, bias, gamma, beta, running_mean and running_var entries, whose keys are built from layer indices (or custom Block attribute names) joined with parameter names.

Define a ResNet-18 network

import mxnet as mx
from mxnet import init, nd
from mxnet.gluon import nn


class Residual(nn.Block):  # This class is also provided in the gluonbook package for later use.
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def forward(self, X):
        Y = nd.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return nd.relu(Y + X)

def resnet_block(num_channels, num_residuals, first_block=False):
    blk = nn.Sequential()
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.add(Residual(num_channels, use_1x1conv=True, strides=2))
        else:
            blk.add(Residual(num_channels))
    return blk


net = nn.Sequential()
net.add(nn.Conv2D(64, kernel_size=11, padding=3, strides=2),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2, padding=1),
        resnet_block(64, 4, first_block=True),
        resnet_block(128, 4),
        resnet_block(256, 4),
        resnet_block(512, 4),
        nn.GlobalAvgPool2D(),
        nn.Dense(1024),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.Dropout(0.4),
        nn.Dense(10))

Randomly initialize the network and save its parameters with the Block class's built-in save_parameters() member function. Gluon defers shape inference until the first forward pass, so one dummy batch is pushed through the net before saving (the single-channel 96x96 input shape is only an assumption for this sketch):

ctx = mx.cpu()  # use mx.gpu() if a GPU is available
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())
net(nd.random.uniform(shape=(1, 1, 96, 96), ctx=ctx))  # trigger deferred shape inference
new_filename = 'tmp.params'
net.save_parameters(new_filename)
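
As a quick sanity check (a sketch, not part of the original workflow), the file can be loaded straight back into a network of the same architecture:

# load_parameters() resolves entries by the same prefix-based keys
# that save_parameters() wrote out.
net.load_parameters(new_filename, ctx=ctx)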

Next, load the file back and inspect the saved structure:

# Load the saved file for analysis
params = nd.load('tmp.params')
# params is a dict
print(isinstance(params, dict))
# Print the dict's keys
for key in params:
    print(key)
The result of nd.load is a dict; printing its keys gives:

  • 0.weight
  • 0.bias
  • 1.gamma
  • 1.beta
  • 1.running_mean
  • 1.running_var
  • 4.0.conv1.weight
  • 4.0.conv1.bias
  • 4.0.conv2.weight
  • 4.0.conv2.bias
  • 4.0.bn1.gamma
  • 4.0.bn1.beta
  • 4.0.bn1.running_mean
  • 4.0.bn1.running_var
  • 4.0.bn2.gamma
  • 4.0.bn2.beta
  • 4.0.bn2.running_mean
  • 4.0.bn2.running_var
  • 4.1.conv1.weight
  • 4.1.conv1.bias
  • 4.1.conv2.weight
  • 4.1.conv2.bias
  • 4.1.bn1.gamma
  • 4.1.bn1.beta
  • 4.1.bn1.running_mean
  • 4.1.bn1.running_var
  • 4.1.bn2.gamma
  • 4.1.bn2.beta
  • 4.1.bn2.running_mean
  • 4.1.bn2.running_var
  • 4.2.conv1.weight
  • 4.2.conv1.bias
  • 4.2.conv2.weight
  • 4.2.conv2.bias
  • 4.2.bn1.gamma
  • 4.2.bn1.beta
  • 4.2.bn1.running_mean
  • 4.2.bn1.running_var
  • 4.2.bn2.gamma
  • 4.2.bn2.beta
  • 4.2.bn2.running_mean
  • 4.2.bn2.running_var
  • 4.3.conv1.weight
  • 4.3.conv1.bias
  • 4.3.conv2.weight
  • 4.3.conv2.bias
  • 4.3.bn1.gamma
  • 4.3.bn1.beta
  • 4.3.bn1.running_mean
  • 4.3.bn1.running_var
  • 4.3.bn2.gamma
  • 4.3.bn2.beta
  • 4.3.bn2.running_mean
  • 4.3.bn2.running_var
  • 5.0.conv1.weight
  • 5.0.conv1.bias
  • 5.0.conv2.weight
  • 5.0.conv2.bias
  • 5.0.conv3.weight
  • 5.0.conv3.bias
  • 5.0.bn1.gamma
  • 5.0.bn1.beta
  • 5.0.bn1.running_mean
  • 5.0.bn1.running_var
  • 5.0.bn2.gamma
  • 5.0.bn2.beta
  • 5.0.bn2.running_mean
  • 5.0.bn2.running_var
  • 5.1.conv1.weight
  • 5.1.conv1.bias
  • 5.1.conv2.weight
  • 5.1.conv2.bias
  • 5.1.bn1.gamma
  • 5.1.bn1.beta
  • 5.1.bn1.running_mean
  • 5.1.bn1.running_var
  • 5.1.bn2.gamma
  • 5.1.bn2.beta
  • 5.1.bn2.running_mean
  • 5.1.bn2.running_var
  • 5.2.conv1.weight
  • 5.2.conv1.bias
  • 5.2.conv2.weight
  • 5.2.conv2.bias
  • 5.2.bn1.gamma
  • 5.2.bn1.beta
  • 5.2.bn1.running_mean
  • 5.2.bn1.running_var
  • 5.2.bn2.gamma
  • 5.2.bn2.beta
  • 5.2.bn2.running_mean
  • 5.2.bn2.running_var
  • 5.3.conv1.weight
  • 5.3.conv1.bias
  • 5.3.conv2.weight
  • 5.3.conv2.bias
  • 5.3.bn1.gamma
  • 5.3.bn1.beta
  • 5.3.bn1.running_mean
  • 5.3.bn1.running_var
  • 5.3.bn2.gamma
  • 5.3.bn2.beta
  • 5.3.bn2.running_mean
  • 5.3.bn2.running_var
  • 6.0.conv1.weight
  • 6.0.conv1.bias
  • 6.0.conv2.weight
  • 6.0.conv2.bias
  • 6.0.conv3.weight
  • 6.0.conv3.bias
  • 6.0.bn1.gamma
  • 6.0.bn1.beta
  • 6.0.bn1.running_mean
  • 6.0.bn1.running_var
  • 6.0.bn2.gamma
  • 6.0.bn2.beta
  • 6.0.bn2.running_mean
  • 6.0.bn2.running_var
  • 6.1.conv1.weight
  • 6.1.conv1.bias
  • 6.1.conv2.weight
  • 6.1.conv2.bias
  • 6.1.bn1.gamma
  • 6.1.bn1.beta
  • 6.1.bn1.running_mean
  • 6.1.bn1.running_var
  • 6.1.bn2.gamma
  • 6.1.bn2.beta
  • 6.1.bn2.running_mean
  • 6.1.bn2.running_var
  • 6.2.conv1.weight
  • 6.2.conv1.bias
  • 6.2.conv2.weight
  • 6.2.conv2.bias
  • 6.2.bn1.gamma
  • 6.2.bn1.beta
  • 6.2.bn1.running_mean
  • 6.2.bn1.running_var
  • 6.2.bn2.gamma
  • 6.2.bn2.beta
  • 6.2.bn2.running_mean
  • 6.2.bn2.running_var
  • 6.3.conv1.weight
  • 6.3.conv1.bias
  • 6.3.conv2.weight
  • 6.3.conv2.bias
  • 6.3.bn1.gamma
  • 6.3.bn1.beta
  • 6.3.bn1.running_mean
  • 6.3.bn1.running_var
  • 6.3.bn2.gamma
  • 6.3.bn2.beta
  • 6.3.bn2.running_mean
  • 6.3.bn2.running_var
  • 7.0.conv1.weight
  • 7.0.conv1.bias
  • 7.0.conv2.weight
  • 7.0.conv2.bias
  • 7.0.conv3.weight
  • 7.0.conv3.bias
  • 7.0.bn1.gamma
  • 7.0.bn1.beta
  • 7.0.bn1.running_mean
  • 7.0.bn1.running_var
  • 7.0.bn2.gamma
  • 7.0.bn2.beta
  • 7.0.bn2.running_mean
  • 7.0.bn2.running_var
  • 7.1.conv1.weight
  • 7.1.conv1.bias
  • 7.1.conv2.weight
  • 7.1.conv2.bias
  • 7.1.bn1.gamma
  • 7.1.bn1.beta
  • 7.1.bn1.running_mean
  • 7.1.bn1.running_var
  • 7.1.bn2.gamma
  • 7.1.bn2.beta
  • 7.1.bn2.running_mean
  • 7.1.bn2.running_var
  • 7.2.conv1.weight
  • 7.2.conv1.bias
  • 7.2.conv2.weight
  • 7.2.conv2.bias
  • 7.2.bn1.gamma
  • 7.2.bn1.beta
  • 7.2.bn1.running_mean
  • 7.2.bn1.running_var
  • 7.2.bn2.gamma
  • 7.2.bn2.beta
  • 7.2.bn2.running_mean
  • 7.2.bn2.running_var
  • 7.3.conv1.weight
  • 7.3.conv1.bias
  • 7.3.conv2.weight
  • 7.3.conv2.bias
  • 7.3.bn1.gamma
  • 7.3.bn1.beta
  • 7.3.bn1.running_mean
  • 7.3.bn1.running_var
  • 7.3.bn2.gamma
  • 7.3.bn2.beta
  • 7.3.bn2.running_mean
  • 7.3.bn2.running_var
  • 9.weight
  • 9.bias
  • 10.gamma
  • 10.beta
  • 10.running_mean
  • 10.running_var
  • 13.weight
  • 13.bias
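
Each value in this dict is a plain NDArray (the file stores raw arrays, not Parameter objects). A quick spot check, reusing the params dict loaded above:

# Every entry is the raw NDArray of one parameter.
print(type(params['0.weight']))  # <class 'mxnet.ndarray.ndarray.NDArray'>
print(params['0.weight'].shape)  # weight tensor of the first Conv2D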

From these keys we can see how each key is composed:

  • If a layer sits inside a Sequential, its index within the Sequential represents that layer.
  • If it is a custom object subclassing Block, the attribute name from the class definition is used (conv1, bn2, and so on in Residual above).
  • The final component is the parameter's key in that block's _reg_params dict member.

For example, 4.0.conv1.weight reads: child 4 of the top-level Sequential, child 0 of that stage, the conv1 attribute of that Residual, and finally its weight parameter, as the sketch below demonstrates.
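
A minimal sketch of this mapping, assuming the net built above has completed shape inference via the forward pass:

# The dotted key mirrors the attribute path through the network:
# net[4] is the first residual stage, net[4][0] its first Residual block.
w = net[4][0].conv1.weight
print(w.name, w.data().shape)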

_reg_params stores the layer's own parameters as a dict keyed by parameter name. Note that its values are Parameter objects rather than raw NDArrays; for a Conv2D layer it looks like:

conv2d._reg_params = {'weight': Parameter, 'bias': Parameter}
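
A minimal sketch of inspecting _reg_params on a standalone layer (the variable name conv here is just for illustration):

from mxnet.gluon import nn

conv = nn.Conv2D(64, kernel_size=3)
# _reg_params is populated at construction time, before initialization;
# its values are Parameter objects, not raw NDArrays.
for name, param in conv._reg_params.items():
    print(name, type(param).__name__)
# weight Parameter
# bias Parameter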

 

block.save_parameters()

Now let's look at how block.save_parameters() produces the structure shown above.

# A member function of Block that collects all of the block's parameters
# recursively. Blocks can be nested to arbitrary depth, so this is a
# DFS-style traversal of the block tree.
def _collect_params_with_prefix(self, prefix=''):
    if prefix:
        prefix += '.'
    # Add this block's own parameters
    ret = {prefix + key: val for key, val in self._reg_params.items()}
    # Add the parameters of all child blocks
    for name, child in self._children.items():
        # Recurse, passing down the accumulated prefix plus this child's key.
        # In a Sequential, the key is the child's position: 0, 1, 2, 3, ...
        # In a custom block, it is the attribute name from the class definition.
        # dict.update merges the child's dict into this one.
        ret.update(child._collect_params_with_prefix(prefix + name))
    return ret

def save_parameters(self, filename):
    # Collect the flat dict of all parameters with prefixed keys
    params = self._collect_params_with_prefix()
    # Parameter._reduce() gathers each parameter's data into a single
    # NDArray on the CPU (averaging across devices if needed)
    arg_dict = {key: val._reduce() for key, val in params.items()}
    # ndarray.save serializes a dict of NDArrays to disk
    ndarray.save(filename, arg_dict)
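
To confirm that this is exactly what ends up on disk, a quick check (a sketch, reusing the net and tmp.params from above) compares the keys collected in memory with the keys read back from the file:

# The two key sets should match one-to-one.
collected = net._collect_params_with_prefix()
saved = nd.load('tmp.params')
print(set(collected) == set(saved))  # expected: True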

 
