一、只保存网络中的参数
保存:
torch.save(model.state_dict(), save_fp)
加载的时候需要先初始化一个模型,然后把文件中的参数恢复。
train_weights = torch.load(model_fp)
model = Model()
model.load_state_dict(model_weights)
这里load得到的是变量类型为OrderedDict(),也就是网络中的参数集合。
二、保存网络结构和参数
保存:
torch.save(model, save_fp)
加载:
model = torch.load(model_fp)
这里load的到的一个对象,类型是<class Model>