PyTorch笔记:如何保存与加载checkpoints

https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html

保存和加载checkpoints很有帮助。
为了保存checkpoints,必须将它们放在字典对象里,然后使用torch.save()来序列化字典。一个通用的PyTorch做法时使用.tar拓展名保存checkpoints。
加载时,首先需要初始化模型和优化器,然后使用torch.load()加载定义

import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

保存checkpoints

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

加载

model = Net()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
### Python DeepSeek 和 Jupyter Notebook 使用教程 #### 安装必要的库 为了能够在Jupyter Notebook环境中使用DeepSeek,首先需要确保已经安装了所需的依赖项。可以按照如下方式来设置开发环境: ```bash pip install jupyter deepseek mmdet torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 ``` 上述命令会安装`jupyter`, `deepseek`以及用于目标检测的`mmdet`框架及其依赖项[^2]。 #### 启动Jupyter Notebook并创建新笔记本 启动Jupyter Notebook服务之后,在浏览器中打开对应的链接,并新建一个Python 3内核的Notebook文件: ```bash jupyter notebook ``` 这一步骤允许用户在一个基于Web界面的交互式编程环境中编写和运行代码片段[^4]。 #### 加载预训练模型进行推理 下面是在Jupyter Notebook单元格内的具体操作实例,展示如何加载配置文件、权重参数初始化detector对象并对图片执行推断过程: ```python from mmdet.apis import init_detector, inference_detector config_file = 'configs/dino_x/dino-x_1.5b_fpn_36e_vit.py' checkpoint_file = 'checkpoints/dino-x_1.5b.pth' model = init_detector(config_file, checkpoint_file, device='cuda:0') result = inference_detector(model, 'goods_image.jpg') print(result.pred_instances) ``` 这段脚本展示了怎样利用MMDetection中的API接口完成图像识别任务,并打印预测结果的信息。 #### 利用DeepSeek-R1插件增强生产力 对于特定应用场景下的代码自动生成需求,可以通过集成DeepSeek-R1插件的方式提高效率。该工具能够根据给定的任务描述快速生成高质量的基础代码结构,减少重复劳动的同时也降低了错误发生的可能性。 #### 对话记忆技巧的应用 当涉及到较为复杂的项目时,合理运用工作区的记忆特性可以帮助更好地管理和追踪不同阶段的工作进展。例如设定约束条件以保证代码风格的一致性和可读性: ```plaintext # 使用工作区记忆功能 /remember 项目使用Python3.9 + Django4.2 /constraint 必须通过PEP8严格校验 ``` 这样的做法有助于团队成员之间共享一致的技术栈标准,同时也便于后续维护工作的开展[^3]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值