Save and Load
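The previous chapter covered how to tune hyperparameters and train a network model. During training, we usually want to save both intermediate and final results. These can then be used for fine-tuning (fine-tune) and for later model inference and deployment. This chapter describes how to save and load models.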
```python
import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor

def network():
    # A simple MLP for 28x28 single-channel images with 10 output classes
    model = nn.SequentialCell(
        nn.Flatten(),
        nn.Dense(28*28, 512),
        nn.ReLU(),
        nn.Dense(512, 512),
        nn.ReLU(),
        nn.Dense(512, 10))
    return model
```
Saving and Loading Model Weights
To save a model, call the `save_checkpoint` interface and pass in the network and the path to save to:
```python
model = network()
# Write the network's parameters to a checkpoint file
mindspore.save_checkpoint(model, "model.ckpt")
```
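The checkpoint holds the values of the network's parameters. As a rough worked example, the sketch below counts how many values `network()` contributes. It rests on three assumptions. Each `nn.Dense(in, out)` stores an `out × in` weight matrix plus an `out`-element bias. `nn.Flatten` and `nn.ReLU` have no parameters. Every value is float32 (4 bytes). Parameter names and metadata in the real `.ckpt` file are not counted, so the size on disk will differ somewhat.

```python
# Worked arithmetic: how many parameter values the network() MLP holds.
# Assumptions: Dense(in, out) has an out*in weight plus an out-element bias;
# Flatten and ReLU have no parameters; values are float32 (4 bytes each).
DENSE_LAYERS = [(28 * 28, 512), (512, 512), (512, 10)]  # (in, out) for each Dense layer


def dense_param_count(in_features, out_features):
    """Number of scalars in one Dense layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


per_layer = [dense_param_count(i, o) for i, o in DENSE_LAYERS]
total = sum(per_layer)
float32_bytes = total * 4  # raw tensor data only, no names or metadata

print(per_layer)      # [401920, 262656, 5130]
print(total)          # 669706
print(float32_bytes)  # 2678824
```

The first layer dominates. About 60% of all values sit in the 784 → 512 projection.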
To load model weights, first create an instance of the same model. Then load the parameters with the `load_checkpoint` and `load_param_into_net` methods.
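```python
model = network()
# Read the checkpoint into a dict that maps parameter names to Parameters
param_dict = mindspore.load_checkpoint("model.ckpt")
# Copy the parameters into the network; returns (param_not_load, ckpt_not_load)
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
```

[]

`param_not_load` lists the network parameters that were not loaded. An empty list means all parameters were loaded successfully.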
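`load_param_into_net` pairs checkpoint entries with network parameters by name. This is why the new instance must be built with the same structure. The plain-Python sketch below is a simplified, hypothetical model of that name matching. It makes the meaning of the two returned lists concrete. `match_params` and the parameter names are invented for illustration and do not reflect MindSpore's internals.

```python
# Simplified, hypothetical model of name-based parameter loading.
# This is NOT MindSpore's implementation; the real function may also check
# shapes and dtypes. Parameter names below are illustrative only.


def match_params(net_params, ckpt):
    """Copy checkpoint values into net_params wherever the names match.

    Returns (param_not_load, ckpt_not_load):
      param_not_load: network parameter names with no checkpoint entry
      ckpt_not_load:  checkpoint names that matched no network parameter
    """
    param_not_load = []
    for name in net_params:
        if name in ckpt:
            net_params[name] = ckpt[name]  # overwrite an existing key only
        else:
            param_not_load.append(name)
    ckpt_not_load = [name for name in ckpt if name not in net_params]
    return param_not_load, ckpt_not_load


ckpt = {"layer1.weight": [0.5, -0.2], "layer1.bias": [0.1]}

# Same structure as when the checkpoint was saved: everything loads
same_net = {"layer1.weight": None, "layer1.bias": None}
print(match_params(same_net, ckpt))     # ([], [])

# A renamed layer: its weight stays unloaded and the checkpoint entry goes unused
renamed_net = {"fc.weight": None, "layer1.bias": None}
print(match_params(renamed_net, ckpt))  # (['fc.weight'], ['layer1.weight'])
```

With the real API, a non-empty `param_not_load` usually signals that the network definition no longer matches the one used when saving.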
Saving and Loading MindIR
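Besides Checkpoint, MindSpore provides an Intermediate Representation (IR) that is shared by the cloud side (training) and the device side (inference). The `export` interface saves a model directly as MindIR. Currently only strict graph mode is supported.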
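```python
# MindIR export requires strict graph mode; ideally set this before creating any Tensor or network
mindspore.set_context(mode=mindspore.GRAPH_MODE, jit_syntax_level=mindspore.STRICT)
model = network()
# A dummy input of shape [1, 1, 28, 28] that fixes the exported input shape
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
# Produces model.mindir
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
```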
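MindIR stores both the checkpoint (the weights) and the model structure. An input Tensor must therefore be defined so that the input shape can be obtained.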
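To show how the input Tensor's shape flows through the network, the sketch below traces shapes by hand. It follows the layers of `network()` from `[1, 1, 28, 28]` to the `(1, 10)` output printed by the inference example. The helper functions are hypothetical and only mimic each layer's shape rule. Flatten keeps the batch dimension and merges the rest. Dense replaces the last dimension. ReLU leaves the shape unchanged.

```python
from math import prod


def flatten(shape):
    """Like nn.Flatten: keep the batch dimension, merge all remaining dims."""
    return (shape[0], prod(shape[1:]))


def dense(shape, in_features, out_features):
    """Like nn.Dense: the last dim must equal in_features and becomes out_features."""
    assert shape[-1] == in_features, f"expected last dim {in_features}, got {shape[-1]}"
    return shape[:-1] + (out_features,)


def relu(shape):
    """ReLU is element-wise, so the shape is unchanged."""
    return shape


def trace_network(input_shape):
    """Follow the layer order of network() and record every intermediate shape."""
    shapes = [tuple(input_shape)]
    shapes.append(flatten(shapes[-1]))
    shapes.append(relu(dense(shapes[-1], 28 * 28, 512)))
    shapes.append(relu(dense(shapes[-1], 512, 512)))
    shapes.append(dense(shapes[-1], 512, 10))
    return shapes


shapes = trace_network((1, 1, 28, 28))
print(shapes)  # [(1, 1, 28, 28), (1, 784), (1, 512), (1, 512), (1, 10)]
```

In this sketch, an input of `(1, 1, 32, 32)` would flatten to 1024 features and fail the first `dense` check. That is why the example input must match the shape the network expects.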
An existing MindIR model can be loaded with the `load` interface. Wrap the result in `nn.GraphCell` to run inference.
Note that `nn.GraphCell` supports graph mode only.
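```python
# Load the exported MindIR graph and wrap it in a GraphCell for inference
graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)  # reuses the input Tensor defined for export
print(outputs.shape)
```

(1, 10)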