深度学习模型的加载和保存(.pth .ckpt)

一        .pth类型

(该类型可以保存和加载模型或者模型参数)

1 保存和加载整个模型

import torch
import torch.nn as nn
import os

# path
path = r"D:\program\draft"
module_pth = os.path.join(path, 'module.pth')

module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])
# save whole module
torch.save(module, module_pth)
# load whole module
new_module = torch.load(module_pth, weights_only=False)

2 保存和加载部分模型

import torch
import torch.nn as nn
import os

# path
path = r"D:\program\draft"
module_pth = os.path.join(path, 'module.pth')

module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])
# save part module
torch.save(module[0], module_pth)
# load part module
new_module_0 = torch.load(module_pth, weights_only=False)
new_module_1 = nn.Linear(3, 1)
new_module = nn.ModuleList([new_module_0, new_module_1])

3 保存和加载模型所有参数

import torch
import torch.nn as nn
import os

# path
path = r"D:\program\draft"
weights_pth = os.path.join(path, 'weights.pth')

module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])
# save whloe weights
torch.save(module.state_dict(), weights_pth)
# load whole weights
new_module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])
new_module.load_state_dict(torch.load(weights_pth, weights_only=True))

4 保存和加载模型部分参数

import torch
import torch.nn as nn
import os

# path
path = r"D:\program\draft"
weights_pth = os.path.join(path, 'weights.pth')

module = nn.ModuleList([nn.Linear(4, 3), nn.Linear(3, 1)])
# save part weights
torch.save(module[0].state_dict(), weights_pth)
# load part weights
new_module_0 = nn.Linear(4, 3)
new_module_0.load_state_dict(torch.load(weights_pth, weights_only=True))
new_module_1 = nn.Linear(3, 1)
new_module = nn.ModuleList([new_module_0, new_module_1])

二        .ckpt类型

(可以保存模型参数,学习率,优化器,loss值等等)

保存checkpoint

import torch

checkpoint = {
    'epoch': epoch,                     
    'model_state_dict': model.state_dict(), 
    'optimizer_state_dict': optimizer.state_dict()
}

# save checkpoint
torch.save(checkpoint, 'model_checkpoint.ckpt')

仅保存模型参数

torch.save(model.state_dict(), 'model_weights.ckpt')

加载checkpoint(常用于继续训练)

import torch

model = nn.Linear(4, 3)
optimizer = torch.optim.Adam(model.parameters())

# load checkpoint
checkpoint = torch.load('model_checkpoint.ckpt')

# load lr optimizer weights loss
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

仅加载参数(常用于推理)

​
import torch

model = nn.Linear(4, 3)
optimizer = torch.optim.Adam(model.parameters())

# load checkpoint
checkpoint = torch.load('model_checkpoint.ckpt')

# load weights
model.load_state_dict(checkpoint['model_state_dict'])

后续可能还会补充,欢迎大家批评指正,交流学习

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值