1 torch.save()
for example:
①保存整个模型:torch.save(model,'---.pth')
model:需要保存的模型,即神经网络的model类;
'---.pth':保存地址及文件名,如:'logs/Epoch1.pth',指:保存在logs文件夹下,保存的pth格式的文件名为Epoch1.pth。
②只保存训练好的权重:torch.save(model.state_dict(), '---.pth')
model.state_dict():model:class类,指只保存模型训练好的权重;后面同理。
例子:torch.save(net.state_dict(), 'logs/Epoch1.pth')
2 torch.load()
for example:
net.load_state_dict(torch.load('logs/Epoch1.pth'))
这个例子与1 torch.save()中的②相对应,net:我起的model的class类名,load_state_dict:下载之前保存的权重;torch.load:下载指令函数;('logs/Epoch1.pth'):指文件夹及文件名。
3 保存权重推荐:
保存: torch.save(model.state_dict(), PATH)
下载:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
使用torch.save()保存state_dict,能够方便模型的加载。因此推荐使用这种方式进行模型保存。使用model.eval()来固定dropout和归一化层,否则每次推理会生成不同的结果。注意,load_state_dict()需要传入字典对象,因此需要先反序列化state_dict再传入load_state_dict()
4 保存整个模型推荐:
保存: torch.save(model, PATH)
下载
model = torch.load(PATH)
model.eval()