模型下载与保存
1.导入包
import torch
import torch.onnx as onnx
import torchvision.models as models
2.保存并下载模型权重
PyTorch模型将学习到的参数存储在内部状态字典(称为state_dict)中。这些参数可以通过torch.save方法保存:
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
下载:
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()#若有dropout or batch normalization(BN)需加model.eval保证BN和dropout不发生变化,pytorch框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层影响结果。如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()以启用 BatchNormalization 和 Dropout。
或者直接通过model保存和下载
#保存
torch.save(model, 'model.pth')
#下载
model = torch.load('model.pth')
3.onnx用于保存模型,ONNX是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的深度学习框架(如Pytorch, MXNet)可以采用相同格式存储模型数据。简而言之,ONNX是一种便于在各个主流深度学习框架中迁移模型的中间表达格式。
input_image = torch.zeros((1,3,224,224))
onnx.export(model, input_image, 'model.onnx')