之前想学习保存和加载模型的代码,在知乎上看到一个回答,发现两行代码就可以搞定,于是兴冲冲的加上了:
torch.save(model, "model.pth.tar")
model_dict=torch.load("model.pth.tar")
然后就大胆的去训练了,结果训练结束,准备load
时,发现load
得到的结果,就只有模型的结构,参数完全没保存下来…
(哎,当时看到答主说这种方式是保存了整个网络,就以为整个网络必然包括参数啊,谁知道仅仅是结构)
于是换了一种方式:
checkpoint = {
"model_struct": model,
"model_param": model.state_dict(),
"model_cfg": config}
torch.save(checkpoint, “model.ckpt")
这种方式是自己建立一个字典checkpoint
,然后分别保存模型结构model_struct
、模型参数model_param
和相关配置model_cfg
,然后保存为.ckpt
文件(至于为何.pth.tar
和.ckpt
到底有什么不一样,暂时还不清楚)
这次的教训就是,单保存模型是不能保存网络参数的,需要调用模型的.state_dict()
属性将参数拿出来