针对仅推理场景,把参数直接加载到网络中,以便后续的推理验证。
-
加载模型
from mindspore.train.serialization import load_checkpoint, load_param_into_net
# 加载已经保存的用于测试的模型
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
# 加载参数到网络中
load_param_into_net(net, param_dict)
-
验证模型
我们使用生成的模型进行单个图片数据的分类预测,具体步骤如下:
# 定义分类类别
classes = [
"Zero",
"One",
"Two",
"Three",
"Four",
"Fives",
"Six",
"Seven",
"Eight",
"Nine",
]
# image为测试图片,label为测试图片的实际分类
image, label = test_data[0][0], test_data[0][1]
# 使用函数model.predict预测image对应分类
pred = model.predict(image)
predicted, actual = classes[pred[0].argmax(0)], classes[label]
# 输出预测分类与实际分类
print(f'Predicted: "{predicted}", Actual: "{actual}"')
-
运行结果示例如下:
Predicted: "Eight", Actual: "Eight"