mxnet学习(9):使用gluon接口读取symbol预训练模型finetune

使用symbol得到的模型或者gluon的hybridize之后的模型包括一个.json文件(网络结构)和.params文件(参数),gluon可以使用net = gluon.SymbolBlock.imports(json, ['data'], params, ctx)导入网络和参数,这样可以进行测试或者进一步训练。

但是如果只需要使用模型的其中一部分,比如只需要conv层,去掉所有fc层,或者再另外增加一些层, 这样直接导入就会比较复杂。正确的做法如下:

sym, arg_params, aux_params = mx.model.load_checkpoint("1.0.3", 40)#这里是model的名字和参数对应的epoch
layers = sym.get_internals()#得到所有的layers
outputs = layers['stage4_unit1_conv2_output']#选择输出层
inputs = layers['data']#选择输入层
net = gluon.SymbolBlock(outputs, inputs)#使用gluon的接口将其封装成一个新的net
net.load_parameters("1.0.3-0040.params", ignore_extra = True, allow_missing = True)#载入数据
y = net(data)
print(y.shape)

如果需要在该网络的基础上再新增加一些层,如下定义:

class PretrainedNetwork(gluon.HybridBlock):
    def __init__(self, pretrained_layer, **kwargs):
        super(PretrainedNetwork, self).__init__(**kwargs)
        with self.name_scope():
            self.pretrained_layer = pretrained_layer #(n, 4, 4, 128)
            self.fc = nn.HybridSequential()
            self.fc.add(
                        nn.Flatten(),
                        nn.Dense(256, activation = 'relu'),
                        nn.Dropout(rate = 0.5),
                        nn.Dense(128)
                        )
            self.single_fc = nn.Dense(2)
            self.fusion_fc = nn.Dense(2)
            
    def hybrid_forward(self, F, x):
        x = self.pretrained_layer(x)
        x = self.fc(x)
        feat = x
        y1 = self.single_fc(x)
        feat = feat.sum(axis = 1)
        y2 = self.fusion_fc(feat)
        return y1, y2

那么可以通过下面的方式,使用预训练模型初始化其中一部分:

net = PretrainedNetwork(pretrained_layer = net)
net.initialize(forece_reinit = False, init = init.Xavier())

需要注意的是,要先load_parameters再用其初始化PretrainedNwtwork,否则容易出现prefix不匹配的问题。

如果需要fix其中一部分参数,只训练其中一部分,可以通过观察所有layer的名字,找到需要训练的layer。

print(net.collect_params())#打印所有的参数,这样可以看到所有的layer及其参数

Trainerparams通过正则表达式选择需要训练的参数:

trainer = gluon.Trainer(params = net.collect_params("pretrained*|dense0*"), optimizer = optimizer)

这样没有被选中的参数就会被fix,训练中不会改变。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值