今天这篇文章讨论一下我们进行深度学习时,如何将预测的结果如何转换为对应的标签以及如何将最后的模型进行保存和加载。
预测
为了更清晰地说明整个过程,我们还是以代码来说明一下:
from PIL import Iamge
labels=['cat','fish']
img=Image.open(FileName)
img=transforms(img)
img=img.unsequeeze(0)
prediction=simplenet(img)
prediction=prediction.argmax()
print(labels[prediction])
得到预测的结果很简单,只需要把我们的批次(batch)传入模型。然后要找出有较大概率的类。在这里,可以简单地将张量转换为一个数组,并比较这两个元素,不过通常会有更多元素。PyTorch提供了argmax()函数,这很有用,它会返回张量中最大值得索引。然后使用这个索引访问我们的标签,打印出预测结果。
模型保存
如果你对一个模型的性能很满意,或者由于某个原因需要停止训练,可以使用torch.save()方法采用Python的pickle格式保存模型的当前状态。反过来,我们也可以使用torch.load()方法加载之前保存的一个模型迭代。
所以,保存当前参数和模型结构的代码,如下所示:
torch.save('simplenet','./filedir')
可以使用如下代码进行加载代码:
simplenet=torch.load('./modeldir')
这会把参数以及模型的结构都保存到一个文件中。如果以后某个时间改变了模型的结构,可能就会有问题。由于这个原因,更常见的做法时保存模型的state_dict。这是一个标准的Python dict,其中包含模型中每一层参数的映射。可以保存如下state_dict:
torch.save(model.state_dict(),PATH)
恢复时,首先创建模型的一个实例,再使用load_state_dict。
simplenet=SimpleNet()
simplenet_state_dict=torch.load("./modeldir")
simplenet.load_state_dict(simplenet_state_dict)
这样的好处是,如果以某种方式扩展了模型,可以向load_state_dict提供了一个strict=False参数,为state_dict中确实有的模型层指定相应的参数,而如果所加载的state_dict与模型当前结构相比缺少或增加了某些层,也不会失败。因为这只是一个普通的Python dict 。可以改变键名来适应你的模型,如果要从一个完全不同的模型抽取参数,这会很方便。
注:文章摘选自《基于PyTroch的深度学习》 Ian Pointer著