【pytorch模型加载和保存】

10 篇文章 0 订阅
本文详细介绍了如何在PyTorch中保存和加载模型,包括单独保存参数、加载参数以及保存整个模型。以一个多层感知机为例,展示了从数据准备到模型训练和推理的完整流程。
摘要由CSDN通过智能技术生成

pytorch模型加载和保存

在PyTorch中,模型的保存和加载是非常常见的操作。

1. 模型保存

保存模型参数(state_dict):

这是推荐的方式,仅保存模型的参数,不包括优化器状态或其他训练相关的状态。

import torch
from your_model_module import YourModelClass

# 创建模型实例
model = YourModelClass()

# 假设模型已经训练完成,并且权重已经更新
# ...

# 定义保存路径
torch.save(model.state_dict(), 'model_weights.pth')

# 如果需要保存模型结构和参数一起
# 可以通过torch.jit.trace或torch.jit.script方法将模型转换为ScriptModule,
# 然后保存,例如:
# traced_script_module = torch.jit.trace(model, example_input)
# torch.jit.save(traced_script_module, 'model_traced.pth')

2. 加载模型参数

# 加载模型参数时,先创建模型实例,然后加载保存的state_dict
model = YourModelClass()
model.load_state_dict(torch.load('model_weights.pth'))

# 注意,加载前确保模型架构与保存的权重匹配
# 如果是在其他地方定义的模型,可以这样加载:
# model = YourModelClass(*args, **kwargs)  # 这里的*args和**kwargs是模型初始化所需的参数
# model.load_state_dict(torch.load('model_weights.pth'), strict=True)  # strict=True表示严格检查键值对是否完全匹配
# 如果模型有未加载的参数或者有些参数不再需要,可以设置strict=False来跳过这些检查

3. 保存并加载整个模型(包括结构)

如果你想保存完整的模型结构,可以直接保存整个模型对象:

# 保存整个模型(包括结构)
torch.save(model, 'entire_model.pth')

# 加载整个模型
loaded_model = torch.load('entire_model.pth')
# 此时不需要调用load_state_dict,因为模型已经被完整地恢复了

这种方式可能在不同的环境中复现性较差,因为它依赖于模型类的可用性和环境中的其他因素。因此,推荐的做法仍然是分开保存模型结构(通常作为代码的一部分)和模型参数(.pth文件)。当需要在不同环境中加载模型时,首先按照相同的结构重新构建模型,然后再加载参数。

pytorch 模型训练和推理步骤

使用PyTorch构建、训练和推理模型的详细步骤及代码示例,以一个简单的多层感知机(MLP)为例,数据集为随机生成的数据。

1. 导入所需库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

2. 定义自定义数据集类 (假设我们有输入和标签)

class RandomDataset(Dataset):
    def __init__(self, n_samples=1000, input_dim=10, output_dim=1):
        self.x = torch.randn(n_samples, input_dim)
        self.y = torch.randn(n_samples, output_dim)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

3. 定义模型类

class MLP(nn.Module):
    def __init__(self, input_dim=10, hidden_dim=50, output_dim=1):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

4. 构建模型并设置优化器

model = MLP()
criterion = nn.MSELoss()  # 假设是回归问题,使用均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用SGD优化器

# 或者也可以选择其他优化器,例如Adam
# optimizer = optim.Adam(model.parameters(), lr=0.001)

5. 数据加载器

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# 训练前通常还会对数据进行标准化处理等预处理工作,此处为了简单起见省略

6. 模型训练

num_epochs = 100
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        # 前向传播
        outputs = model(inputs)

        # 计算损失
        loss = criterion(outputs, targets)

        # 反向传播与优化
        optimizer.zero_grad()  # 清空梯度缓存
        loss.backward()  # 计算梯度
        optimizer.step()  # 更新权重

    # 每个epoch结束后打印损失值
    print(f"Epoch {epoch + 1}: Loss: {loss.item():.4f}")

7. 模型推理(预测)

# 在训练完成后,可以使用模型进行推理
test_data = torch.randn(10, 10)  # 假设这是新的测试样本
with torch.no_grad():  # 推理时不需要计算梯度
    predictions = model(test_data)
print("Predictions:", predictions)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

stsdddd

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值