在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,接下来将介绍如何保存与加载模型。
1.构建模型
import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor
def network():
model = nn.SequentialCell(
nn.Flatten(),
nn.Dense(28*28, 512),
nn.ReLU(),
nn.Dense(512, 512),
nn.ReLU(),
nn.Dense(512, 10))
return model
这里是没有经过训练的,可以直接用上一节训练的模型model。
2、保存和加载权重
2.1 保存
保存模型使用save_checkpoint
接口,传入网络和指定的保存路径:
mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode='AES-GCM', choice_func=None, **kwargs)
【参数】
-
save_obj (Union[Cell, list, dict]) - 待保存的对象。数据类型可为 mindspore.nn.Cell 、list或dict。若为list,可以是 Cell.trainable_params() 的返回值,或元素为dict的列表(如[{“name”: param_name, “data”: param_data},…],param_name 的类型必须是str,param_data 的类型必须是Parameter或者Tensor);若为dict,可以是 mindspore.load_checkpoint() 的返回值。
-
ckpt_file_name (str) - checkpoint文件名称。如果文件已存在,将会覆盖原有文件。
-
integrated_save (bool) - 在并行场景下是否合并保存拆分的Tensor。默认值:
True
。 -
async_save (bool) - 是否异步执行保存checkpoint文件。默认值:
False
。 -
append_dict (dict) - 需要保存的其他信息。dict的键必须为str类型,dict的值类型必须是int、float、bool、string、Parameter或Tensor类型。默认值:
None
。 -
enc_key (Union[None, bytes]) - 用于加密的字节类型密钥。如果值为
None
,那么不需要加密。默认值:None
。 -
enc_mode (str) - 该参数在 enc_key 不为
None
时有效,指定加密模式,目前仅支持"AES-GCM"
,"AES-CBC"
和"SM4-CBC"
。默认值:"AES-GCM"
。 -
choice_func (function) - 一个用于自定义控制保存参数的函数。函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回
True
,则匹配自定义条件的Parameter将被保存。 如果返回False
,则未匹配自定义条件的Parameter不会被保存。默认值:None
。 -
kwargs (dict) - 配置选项字典。
model = network()
mindspore.save_checkpoint(model, "model.ckpt")
运行之后会在同路径下找到一个文件
2.2 加载
要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpoint
和load_param_into_net
方法加载参数。
2.2.1 load_checkpoint
mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode='AES-GCM', specify_prefix=None, choice_func=None)
【参数】
-
ckpt_file_name (str) - checkpoint的文件名称。
-
net (Cell) - 加载checkpoint参数的网络。默认值:
None
。 -
strict_load (bool) - 是否将严格加载参数到网络中。如果是
False
,它将根据相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行强制精度转换,比如将 float32 转换为 float16 。默认值:False
。 -
filter_prefix (Union[str, list[str], tuple[str]]) - 废弃(请参考参数 choice_func)。以 filter_prefix 开头的参数将不会被加载。默认值:
None
。 -
dec_key (Union[None, bytes]) - 用于解密的字节类型密钥,如果值为
None
,则不需要解密。默认值:None
。 -
dec_mode (str) - 该参数仅当 dec_key 不为
None
时有效。指定解密模式,目前支持"AES-GCM"
,"AES-CBC"
和"SM4-CBC"
。默认值:"AES-GCM"
。 -
specify_prefix (Union[str, list[str], tuple[str]]) - 废弃(请参考参数 choice_func)。以 specify_prefix 开头的参数将会被加载。默认值:
None
。 -
choice_func (Union[None, function]) - 函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回
True
,则匹配自定义条件的Parameter将被加载。 如果返回False
,则匹配自定义条件的Parameter将被删除。默认值:None
。
2.2.2 load_param_into_net
mindspore.load_param_into_net(net, parameter_dict, strict_load=False)将参数加载到网络中,返回网络中没有被加载的参数列表。
【参数】
-
net (Cell) - 将要加载参数的网络。
-
parameter_dict (dict) - 加载checkpoint文件得到的字典。
-
strict_load (bool) - 是否将参数严格加载到网络中。如果是
False
, 它将以相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行精度转换,比如将 float32 转换为 float16 。默认值:False
。
model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
【运行结果】
param_not_load
是未被加载的参数列表,为空时代表所有参数均加载成功。
3、保存和加载MindIR
除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export
接口直接将模型保存为MindIR。
3.1 保存
mindspore.export(net, *inputs, file_name, file_format, **kwargs)将MindSpore网络模型导出为指定格式的文件。
【参数】
-
net (Union[Cell, function]) - MindSpore网络结构。
-
inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]) - 网络的输入,如果网络有多个输入,需要一同传入。当传入的类型为 Dataset 时,将会把数据预处理行为同步保存起来。需要手动调整batch的大小,当前仅支持获取 Dataset 的 image 列。
-
file_name (str) - 导出模型的文件名称。
-
file_format (str) - MindSpore目前支持导出”AIR”,”ONNX”和”MINDIR”格式的模型。
-
AIR - Ascend Intermediate Representation。一种Ascend模型的中间表示格式。推荐的输出文件后缀是”.air”。
-
ONNX - Open Neural Network eXchange。一种针对机器学习所设计的开放式的文件格式。推荐的输出文件后缀是”.onnx”。
-
MINDIR - MindSpore Native Intermediate Representation for Anf。一种MindSpore模型的中间表示格式。推荐的输出文件后缀是”.mindir”。
-
-
kwargs (dict) - 配置选项字典。
-
enc_key (byte) - 用于加密的字节类型密钥,有效长度为16、24或者32。
-
enc_mode (Union[str, function]) - 指定加密模式,当设置 enc_key 时启用。
-
对于 ‘AIR’和 ‘ONNX’格式的模型,当前仅支持自定义加密导出。
-
对于 ‘MINDIR’格式的模型,支持的加密选项有: ‘AES-GCM’, ‘AES-CBC’, ‘SM4-CBC’和用户自定义加密算法。默认值:
"AES-GCM"
。 -
关于使用自定义加密导出的详情,请查看 教程。
-
-
dataset (Dataset) - 指定数据集的预处理方法,用于将数据集的预处理导入MindIR。
-
obf_config (dict) - 模型混淆配置选项字典。
-
type (str) - 混淆类型,目前支持动态混淆,即 ‘dynamic’ 。
-
obf_ratio (Union[str, float]) - 全模型算子的混淆比例,可取浮点数(0, 1]或者字符串
"small"
、"medium"
、"large"
。"small"
、"medium"
、"large"
分别对应于 0.1、0.3、0.6。 -
customized_func (function) - 在自定义函数模式下需要设置的Python函数,用来控制混淆结构中的选择分支走向。它的返回值需要是bool类型,且是恒定的,用户可以参考不透明谓词进行设置(请查看 动态混淆教程 中的 my_func())。如果设置了 customized_func ,那么在使用 load 接口导入模型的时候,需要把这个函数也传入。
-
obf_random_seed (int) - 混淆随机种子,是一个取值范围为(0, 9223372036854775807]的整数,不同的随机种子会使模型混淆后的结构不同。如果用户设置了 obf_random_seed ,那么在部署混淆模型的时候,需要在调用 mindspore.nn.GraphCell 接口中传入 obf_random_seed 。需要注意的是,如果用户同时设置了 customized_func 和 obf_random_seed ,那么后一种模式将会被采用。
-
-
custom_func (function) - 用户自定义的导出策略的函数。该函数会在网络导出时,对模型使用该函数进行自定义处理。需要注意,当前仅支持对 format 为 MindIR 的文件使用 custom_func ,且自定义函数仅接受一个代表 MindIR 文件 Proto 对象的入参。当使用 custom_func 对模型进行修改时,需要保证修改后模型的正确性,否则可能导致模型加载失败或功能错误。默认值:
None
。
-
model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。
3.2 加载
已有的MindIR模型可以方便地通过load
接口加载,传入nn.GraphCell
即可进行推理。nn.GraphCell
仅支持图模式。
mindspore.nn.GraphCell(graph, params_init=None, obf_random_seed=None)
运行从MindIR加载的计算图。
此功能仍在开发中。目前 GraphCell 不支持修改图结构,在导出MindIR时只能使用shape和类型与输入相同的数据。
【参数】
-
graph (FuncGraph) - 从MindIR加载的编译图。
-
params_init (dict) - 需要在图中初始化的参数。key为参数名称,类型为字符串,value为 Tensor 或 Parameter。如果参数名在图中已经存在,则更新其值;如果不存在,则忽略。默认值:
None
。 -
obf_random_seed (Union[int, None]) - 用于动态混淆保护的混淆随机种子。动态混淆是一种模型保护方法,可以参考 mindspore.obfuscate_model() 。如果导入的 graph 是一个经过混淆的模型,那么须提供 obf_random_seed 。 obf_random_seed 的取值范围是(0, 9223372036854775807]。默认值:
None
。
mindspore.set_context(mode=mindspore.GRAPH_MODE)
graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)
这里时间改了一下,之前差8个小时。