torch.jit.save(net, ‘model.pth’) 和 torch.save(net.state_dict(), ‘model.pth’) 两者在保存模型时的区别
torch.jit.save(net, ‘model.pth’):
这个语句将保存经过 TorchScript 脚本化的模型 net 到名为 ‘model.pth’ 的文件中。
保存的是完整的 TorchScript 模型,包含模型的结构和权重参数。
因此,加载后的模型可以直接用于推理,无需重新定义模型结构。
通过 torch.jit.load(‘model.pth’) 可以加载 TorchScript 模型。
torch.save(net.state_dict(), ‘model.pth’):
这个语句将只保存 PyTorch 模型 net 的状态字典(state dict)到名为 ‘model.pth’ 的文件中。
状态字典是一个 Python 字典,包含了模型的权重参数。它不包含模型的结构信息。
因此,加载模型时,需要先创建一个与原始模型结构相同的模型对象, 然后使用 model.load_state_dict(torch.load(‘model.pth’)) 将权重参数加载到模型中。