一、.pth 文件
（.pth 文件既可以保存整个模型，也可以只保存模型参数 state_dict；扩展名只是约定，torch.save 本身并不关心扩展名）
1. 保存和加载整个模型（整个模型通过 pickle 序列化，加载时需要能导入模型类的定义，且只应加载可信来源的文件）
import torch
import torch.nn as nn
import os

# Directory for the saved file; adjust to your machine.
path = r"D:\program\draft"
# torch.save opens the target path for writing and does not create missing
# parent directories, so make sure the directory exists first.
os.makedirs(path, exist_ok=True)
module_pth = os.path.join(path, 'module.pth')
module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])

# Save the whole module object (structure + parameters, pickled).
torch.save(module, module_pth)

# Load the whole module back. weights_only=False is required because the file
# holds a pickled nn.Module rather than plain tensors; unpickling can execute
# arbitrary code, so only load files from a trusted source.
new_module = torch.load(module_pth, weights_only=False)
2 保存和加载部分模型
import torch
import torch.nn as nn
import os

# Directory for the saved file; adjust to your machine.
path = r"D:\program\draft"
# torch.save does not create missing parent directories.
os.makedirs(path, exist_ok=True)
module_pth = os.path.join(path, 'module.pth')
module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])

# Save only the first sub-module (module[0]) as a whole pickled object.
torch.save(module[0], module_pth)

# Load the saved sub-module. weights_only=False is needed for a pickled
# nn.Module; unpickling can execute arbitrary code, so only load trusted files.
new_module_0 = torch.load(module_pth, weights_only=False)
# The second layer was never saved, so it is freshly initialized here rather
# than restored from disk.
new_module_1 = nn.Linear(3, 1)
new_module = nn.ModuleList([new_module_0, new_module_1])
3 保存和加载模型所有参数
import torch
import torch.nn as nn
import os

# Directory for the saved file; adjust to your machine.
path = r"D:\program\draft"
# torch.save does not create missing parent directories.
os.makedirs(path, exist_ok=True)
weights_pth = os.path.join(path, 'weights.pth')
module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])

# Save the whole state_dict: a name -> tensor mapping of all parameters and
# persistent buffers (no model structure).
torch.save(module.state_dict(), weights_pth)

# Load the whole state_dict into a freshly built model with the same structure.
# A state_dict holds only tensors, so the safe weights_only=True loader works.
new_module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])
new_module.load_state_dict(torch.load(weights_pth, weights_only=True))
4 保存和加载模型部分参数
import torch
import torch.nn as nn
import os

# Directory for the saved file; adjust to your machine.
path = r"D:\program\draft"
# torch.save does not create missing parent directories.
os.makedirs(path, exist_ok=True)
weights_pth = os.path.join(path, 'weights.pth')
module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])

# Save only the first layer's state_dict (its keys are 'weight' and 'bias',
# without the '0.' prefix the full ModuleList would use).
torch.save(module[0].state_dict(), weights_pth)

# Load those weights into a matching nn.Linear(4, 3); the state_dict contains
# only tensors, so weights_only=True is sufficient.
new_module_0 = nn.Linear(4, 3)
new_module_0.load_state_dict(torch.load(weights_pth, weights_only=True))
# The second layer was not saved, so it stays freshly initialized.
new_module_1 = nn.Linear(3, 1)
new_module = nn.ModuleList([new_module_0, new_module_1])
二、.ckpt 文件
（checkpoint 通常是一个字典，除模型参数外还可以保存优化器状态、学习率、当前 epoch、loss 值等；它同样由 torch.save 写入，和 .pth 的区别只在于命名习惯）
保存checkpoint
import torch

# Gather the training state into one dict so training can be resumed later.
# NOTE(review): `epoch`, `model` and `optimizer` are expected to come from the
# surrounding training loop; they are not defined in this snippet.
checkpoint = dict(
    epoch=epoch,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
)
# Write the checkpoint; the .ckpt extension is only a naming convention.
torch.save(checkpoint, 'model_checkpoint.ckpt')
仅保存模型参数
# Save just the parameters (state_dict) of an existing `model`, no optimizer state.
torch.save(obj=model.state_dict(), f='model_weights.ckpt')
加载checkpoint(常用于继续训练)
import torch
import torch.nn as nn

# Rebuild the model and optimizer with the same structure used when saving.
model = nn.Linear(4, 3)
optimizer = torch.optim.Adam(model.parameters())

# Load the checkpoint dict.
# - map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
#   machine; load_state_dict copies the values into the model's own parameters.
# - weights_only=True is passed explicitly because the default differs across
#   PyTorch versions. It works as long as the dict holds only tensors and plain
#   Python values (assumes `epoch` was saved as an int — confirm).
checkpoint = torch.load('model_checkpoint.ckpt', map_location='cpu', weights_only=True)

# Restore model weights, optimizer state (including its learning rate) and the
# saved epoch counter. Loss is not restored: the checkpoint does not store it.
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

# Make sure dropout/batch-norm style layers are in training mode before
# resuming (a freshly built module already is; this makes the intent explicit).
model.train()
仅加载参数(常用于推理)
import torch
import torch.nn as nn

# Rebuild the model with the same structure used when saving. No optimizer is
# needed because only the weights are used for inference.
model = nn.Linear(4, 3)

# map_location='cpu' allows loading a GPU-saved checkpoint on a CPU-only
# machine; weights_only=True is the safe loader for tensor-only content.
checkpoint = torch.load('model_checkpoint.ckpt', map_location='cpu', weights_only=True)

# Restore only the model weights; optimizer state and epoch are ignored.
model.load_state_dict(checkpoint['model_state_dict'])

# Switch dropout/batch-norm style layers to evaluation behavior for inference.
model.eval()