昇思MindSpore学习总结八——模型保存与加载

        在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,接下来将介绍如何保存与加载模型。

1.构建模型

import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor

def network():
    model = nn.SequentialCell(
                nn.Flatten(),
                nn.Dense(28*28, 512),
                nn.ReLU(),
                nn.Dense(512, 512),
                nn.ReLU(),
                nn.Dense(512, 10))
    return model

这里是没有经过训练的,可以直接用上一节训练的模型model。

2、保存和加载权重

2.1 保存

保存模型使用save_checkpoint接口,传入网络和指定的保存路径:

mindspore.save_checkpoint(save_objckpt_file_nameintegrated_save=Trueasync_save=Falseappend_dict=Noneenc_key=Noneenc_mode='AES-GCM'choice_func=None**kwargs)

【参数】

  • save_obj (Union[Cell, list, dict]) - 待保存的对象。数据类型可为 mindspore.nn.Cell 、list或dict。若为list,可以是 Cell.trainable_params() 的返回值,或元素为dict的列表(如[{“name”: param_name, “data”: param_data},…],param_name 的类型必须是str,param_data 的类型必须是Parameter或者Tensor);若为dict,可以是 mindspore.load_checkpoint() 的返回值。

  • ckpt_file_name (str) - checkpoint文件名称。如果文件已存在,将会覆盖原有文件。

  • integrated_save (bool) - 在并行场景下是否合并保存拆分的Tensor。默认值: True 。

  • async_save (bool) - 是否异步执行保存checkpoint文件。默认值: False 。

  • append_dict (dict) - 需要保存的其他信息。dict的键必须为str类型,dict的值类型必须是int、float、bool、string、Parameter或Tensor类型。默认值: None 。

  • enc_key (Union[None, bytes]) - 用于加密的字节类型密钥。如果值为 None ,那么不需要加密。默认值: None 。

  • enc_mode (str) - 该参数在 enc_key 不为 None 时有效,指定加密模式,目前仅支持 "AES-GCM" , "AES-CBC" 和 "SM4-CBC" 。默认值: "AES-GCM" 。

  • choice_func (function) - 一个用于自定义控制保存参数的函数。函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回 True ,则匹配自定义条件的Parameter将被保存。 如果返回 False ,则未匹配自定义条件的Parameter不会被保存。默认值: None 。

  • kwargs (dict) - 配置选项字典。

model = network()
mindspore.save_checkpoint(model, "model.ckpt")

运行之后会在同路径下找到一个文件

2.2 加载

        要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

2.2.1 load_checkpoint

mindspore.load_checkpoint(ckpt_file_namenet=Nonestrict_load=Falsefilter_prefix=Nonedec_key=Nonedec_mode='AES-GCM'specify_prefix=Nonechoice_func=None)

【参数】

  • ckpt_file_name (str) - checkpoint的文件名称。

  • net (Cell) - 加载checkpoint参数的网络。默认值: None 。

  • strict_load (bool) - 是否将严格加载参数到网络中。如果是 False ,它将根据相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行强制精度转换,比如将 float32 转换为 float16 。默认值: False 。

  • filter_prefix (Union[str, list[str], tuple[str]]) - 废弃(请参考参数 choice_func)。以 filter_prefix 开头的参数将不会被加载。默认值: None 。

  • dec_key (Union[None, bytes]) - 用于解密的字节类型密钥,如果值为 None ,则不需要解密。默认值: None 。

  • dec_mode (str) - 该参数仅当 dec_key 不为 None 时有效。指定解密模式,目前支持 "AES-GCM" , "AES-CBC" 和 "SM4-CBC" 。默认值: "AES-GCM" 。

  • specify_prefix (Union[str, list[str], tuple[str]]) - 废弃(请参考参数 choice_func)。以 specify_prefix 开头的参数将会被加载。默认值: None 。

  • choice_func (Union[None, function]) - 函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回 True ,则匹配自定义条件的Parameter将被加载。 如果返回 False ,则匹配自定义条件的Parameter将被删除。默认值: None 。

2.2.2 load_param_into_net

mindspore.load_param_into_net(netparameter_dictstrict_load=False)将参数加载到网络中,返回网络中没有被加载的参数列表。

【参数】

  • net (Cell) - 将要加载参数的网络。

  • parameter_dict (dict) - 加载checkpoint文件得到的字典。

  • strict_load (bool) - 是否将参数严格加载到网络中。如果是 False , 它将以相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行精度转换,比如将 float32 转换为 float16 。默认值: False 。

model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

【运行结果】

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。 

3、保存和加载MindIR

        除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。

3.1 保存

mindspore.export(net*inputsfile_namefile_format**kwargs)将MindSpore网络模型导出为指定格式的文件。

【参数】

  • net (Union[Cell, function]) - MindSpore网络结构。

  • inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]) - 网络的输入,如果网络有多个输入,需要一同传入。当传入的类型为 Dataset 时,将会把数据预处理行为同步保存起来。需要手动调整batch的大小,当前仅支持获取 Dataset 的 image 列。

  • file_name (str) - 导出模型的文件名称。

  • file_format (str) - MindSpore目前支持导出”AIR”,”ONNX”和”MINDIR”格式的模型。

    • AIR - Ascend Intermediate Representation。一种Ascend模型的中间表示格式。推荐的输出文件后缀是”.air”。

    • ONNX - Open Neural Network eXchange。一种针对机器学习所设计的开放式的文件格式。推荐的输出文件后缀是”.onnx”。

    • MINDIR - MindSpore Native Intermediate Representation for Anf。一种MindSpore模型的中间表示格式。推荐的输出文件后缀是”.mindir”。

  • kwargs (dict) - 配置选项字典。

    • enc_key (byte) - 用于加密的字节类型密钥,有效长度为16、24或者32。

    • enc_mode (Union[str, function]) - 指定加密模式,当设置 enc_key 时启用。

      • 对于 ‘AIR’和 ‘ONNX’格式的模型,当前仅支持自定义加密导出。

      • 对于 ‘MINDIR’格式的模型,支持的加密选项有: ‘AES-GCM’, ‘AES-CBC’, ‘SM4-CBC’和用户自定义加密算法。默认值: "AES-GCM"

      • 关于使用自定义加密导出的详情,请查看 教程

    • dataset (Dataset) - 指定数据集的预处理方法,用于将数据集的预处理导入MindIR。

    • obf_config (dict) - 模型混淆配置选项字典。

      • type (str) - 混淆类型,目前支持动态混淆,即 ‘dynamic’ 。

      • obf_ratio (Union[str, float]) - 全模型算子的混淆比例,可取浮点数(0, 1]或者字符串 "small" 、 "medium" 、 "large" 。"small" 、"medium" 、"large" 分别对应于 0.1、0.3、0.6。

      • customized_func (function) - 在自定义函数模式下需要设置的Python函数,用来控制混淆结构中的选择分支走向。它的返回值需要是bool类型,且是恒定的,用户可以参考不透明谓词进行设置(请查看 动态混淆教程 中的 my_func())。如果设置了 customized_func ,那么在使用 load 接口导入模型的时候,需要把这个函数也传入。

      • obf_random_seed (int) - 混淆随机种子,是一个取值范围为(0, 9223372036854775807]的整数,不同的随机种子会使模型混淆后的结构不同。如果用户设置了 obf_random_seed ,那么在部署混淆模型的时候,需要在调用 mindspore.nn.GraphCell 接口中传入 obf_random_seed 。需要注意的是,如果用户同时设置了 customized_func 和 obf_random_seed ,那么后一种模式将会被采用。

    • custom_func (function) - 用户自定义的导出策略的函数。该函数会在网络导出时,对模型使用该函数进行自定义处理。需要注意,当前仅支持对 format 为 MindIR 的文件使用 custom_func ,且自定义函数仅接受一个代表 MindIR 文件 Proto 对象的入参。当使用 custom_func 对模型进行修改时,需要保证修改后模型的正确性,否则可能导致模型加载失败或功能错误。默认值: None 。

model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。

3.2 加载

已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。nn.GraphCell仅支持图模式。

mindspore.nn.GraphCell(graphparams_init=Noneobf_random_seed=None)

运行从MindIR加载的计算图。

此功能仍在开发中。目前 GraphCell 不支持修改图结构,在导出MindIR时只能使用shape和类型与输入相同的数据。

【参数】

  • graph (FuncGraph) - 从MindIR加载的编译图。

  • params_init (dict) - 需要在图中初始化的参数。key为参数名称,类型为字符串,value为 Tensor 或 Parameter。如果参数名在图中已经存在,则更新其值;如果不存在,则忽略。默认值: None 。

  • obf_random_seed (Union[int, None]) - 用于动态混淆保护的混淆随机种子。动态混淆是一种模型保护方法,可以参考 mindspore.obfuscate_model() 。如果导入的 graph 是一个经过混淆的模型,那么须提供 obf_random_seed 。 obf_random_seed 的取值范围是(0, 9223372036854775807]。默认值: None 。

mindspore.set_context(mode=mindspore.GRAPH_MODE)

graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)

 这里时间改了一下,之前差8个小时。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值