模型保存
model = network() mindspore.save_checkpoint(model, "model.ckpt")
要加载模型权重,需要先创建相同模型的实例,然后使用`load_checkpoint`和`load_param_into_net`方法加载参数。
model = network() param_dict = mindspore.load_checkpoint("model.ckpt") param_not_load, _ = mindspore.load_param_into_net(model, param_dict) # `param_not_load`是未被加载的参数列表,为空时代表所有参数均加载成功。 print(param_not_load)
MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示
model = network() inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32)) mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
加载模型 `nn.GraphCell`仅支持图模式
mindspore.set_context(mode=mindspore.GRAPH_MODE) graph = mindspore.load("model.mindir") model = nn.GraphCell(graph) outputs = model(inputs) print(outputs.shape)