保存和加载模型
'''https://pytorch-cn.readthedocs.io/zh/latest/notes/serialization/'''
import torch
import torchvision
model = torchvision.models.resnet18(pretrained=True)
"""---------------保存和加载模型参数(推荐)----------------------------"""
torch.save(model.state_dict(),"temp_file/resnet18")
model.state_dict() :
保存模型的状态字典
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Example:
>>> module.state_dict().keys()
model.state_dict()['conv1.weight'].shape
torch.Size([64, 3, 7, 7])
"""加载模型 和对应的参数"""
the_model = torchvision.models.resnet18()
the_model.load_state_dict(torch.load("temp_file/resnet18"))
<All keys matched successfully>
“”“保存和加载整个模型(不推荐)”“”
torch.save(the_model, PATH)
the_model = torch.load(PATH)
ps :在这种情况下,序列化的数据被绑定到特定的类和固定的目录结构,所以当在其他项目中使用时,或者在一些严重的重构器之后它可能会以各种方式break