lasagne模型参数的查看、保存和读取

先定义模型,这是一个普通的CNN模型:

def build_cnn_feat_extractor(input_var=None, input_shape=(None, 3, 50, 50), n=128):
    assert isinstance(n, int)
    network = OrderedDict()
    network['input'] = lasagne.layers.InputLayer(shape=input_shape, input_var=input_var)
    network['conv1'] = lasagne.layers.Conv2DLayer(
        network['input'], num_filters=32, filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=lasagne.init.GlorotUniform())
    network['pool1'] = lasagne.layers.MaxPool2DLayer(network['conv1'], pool_size=(2, 2))

    network['conv2'] = lasagne.layers.Conv2DLayer(
        network['pool1'], num_filters=64, filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify)
    network['pool2'] = lasagne.layers.MaxPool2DLayer(network['conv2'], pool_size=(2, 2))

    network['conv3'] = lasagne.layers.Conv2DLayer(
        network['pool2'], num_filters=n, filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify)
    network['pool3'] = lasagne.layers.MaxPool2DLayer(network['conv3'], pool_size=(2, 2))

    # A fully-connected layer of 256 units with 50% dropout on its inputs:
    network['fc1'] = lasagne.layers.DenseLayer(
        network['pool3'],
        num_units=256,
        nonlinearity=lasagne.nonlinearities.rectify)
    return network
    
def build_mt_cnn(input_var=None, classes=2, infer_classes=2, input_shape=(None, 3, 50, 50), n=128):
    network = build_cnn_feat_extractor(input_var, input_shape, n)
    network['fc2'] = lasagne.layers.DenseLayer(
        network['fc1'],
        num_units=classes,
        nonlinearity=lasagne.nonlinearities.linear)
    return network
    
#建立模型
network_dict = build_cnn(input_var, classes=10, input_shape=(None, 3, 50, 50))
network = network_dict['fc2']

查看模型参数 lasagne.layers.get_all_params(network)

经过一系列训练后,提取模型的参数,如果用 lasagne.layers.get_all_params(),得到的是参数的变量形式 [W, b],如果想要查看所有参数的值,可以用lasagne.layers.get_all_param_values()。

>>>params = lasagne.layers.get_all_params(network)
>>>params

Out[8]: [W, b, W, b, W, b, W, b, W, b]

'''
返回的params是一个10元素的列表,每个元素是一个数组,数组的维度对应的是每一层的参数。
'''
对于每一个参数来说,都是共享变量,可以用get_value()查看值:

```python
>>>params[0].get_value().shape
Out[10]: (32L, 3L, 3L, 3L)

保存模型参数

这里还是推荐用pickle,它可以直接保存列表形式的参数。

def save_or_load_model_params(save, network):
    if save:
        params = lasagne.layers.get_all_param_values(network)
        with open('./model/params.pickle', 'wb') as f:
            #把模型参数倒入文件中
            pickle.dump(params, f)
    else:
        with open('./model/params.pickle', 'rb') as f:
            params = pickle.load(f)
    #params是列表形式
    return params
#保存模型
save_or_load_model_params(0,  network)

读取模型参数 lasagne.layers.set_all_param_values(network, params)

需要重新定义一个相同结构的模型。

input_var = T.tensor4('inputs')
target_var = T.ivector('targets')

network_dict = build_cnn(input_var, classes=10, input_shape=(None, 3, 50, 50))
network = network_dict['fc2']

params = save_or_load_model_params(1)
lasagne.layers.set_all_param_values(network, params)

要注意的是,不能用pickle直接保存之前的模型(sklearn的模型可以用pickle直接保存),不然读取出来它的 input_var、target_var 是对应不上的。如果实在不想保存参数,可以建一个列表把模型的[input_var、target_var、模型]都放入列表中,再用pickle保存,这样读取出来的模型就是原模型了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

学渣渣渣渣渣

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值