Pytorch 保存和加载模型
Pytorch 保存和加载模型后缀:.pt 和.pth
保存整个模型:
torch.save(model,'save.pt')
只保存训练好的权重:
torch.save(model.state_dict(), 'save.pt')
加载模型:
pretrained_dict = torch.load("save.pt")
只加载模型参数:
model.load_state_dict(torch.load("save.pt")) #model.load_state_dict()函数把加载的权重复制到模型的权重中去
加载某一层的训练的到的参数
conv1_weight_state = torch.load('save.pt')['conv1.weight']