4.3节通过代码清单4-2给出了完整的测试代码,与训练代码相比,测试代码相对要简单一些,整体而言只需要导入模型后执行模型的前向计算即可,不涉及损失的反向传播、参数更新等过程。
通过4.3节的介绍,相信你对整个测试代码有了一个直观的认识,接下来我将详细介绍测试代码的内容,主要包含模型导入、数据读取和预测输出三个部分。
4.4.1 模型导入
MXNet框架在训练深度学习模型过程中会保存两个主要文件,即“.params文件”和“.json文件”,前者是模型的参数,后者是模型的网络结构,因此在导入模型时需要同时导入“.params文件”和“.json文件”。在模型训练部分我们介绍了Module对象的fit() 方法,当我们指定fit() 方法的epoch_end_callback参数后,fit() 方法训练模型就能将训练好的“.params文件”和“.json文件”保存在指定目录下。
在测试代码中,可通过如下代码先配置导入模型所需的参数,然后调用load_model()函数导入模型:
model_prefix = "output/LeNet"
index