pytorch模型加载和保存
在PyTorch中,模型的保存和加载是非常常见的操作。
1. 模型保存
保存模型参数(state_dict):
这是推荐的方式,仅保存模型的参数,不包括优化器状态或其他训练相关的状态。
import torch
from your_model_module import YourModelClass
# 创建模型实例
model = YourModelClass()
# 假设模型已经训练完成,并且权重已经更新
# ...
# 定义保存路径
torch.save(model.state_dict(), 'model_weights.pth')
# 如果需要保存模型结构和参数一起
# 可以通过torch.jit.trace或torch.jit.script方法将模型转换为ScriptModule,
# 然后保存,例如:
# traced_script_module = torch.jit.trace(model, example_input)
# torch.jit.save(traced_script_module, 'model_traced.pth')
2. 加载模型参数
# 加载模型参数时,先创建模型实例,然后加载保存的state_dict
model = YourModelClass()
model.load_state_dict(torch.load('model_weights.pth'))
# 注意,加载前确保模型架构与保存的权重匹配
# 如果是在其他地方定义的模型,可以这样加载:
# model = YourModelClass(*args, **kwargs) # 这里的*args和**kwargs是模型初始化所需的参数
# model.load_state_dict(torch.load('model_weights.pth'), strict=True) # strict=True表示严格检查键值对是否完全匹配
# 如果模型有未加载的参数或者有些参数不再需要,可以设置strict=False来跳过这些检查
3. 保存并加载整个模型(包括结构)
如果你想保存完整的模型结构,可以直接保存整个模型对象:
# 保存整个模型(包括结构)
torch.save(model, 'entire_model.pth')
# 加载整个模型
loaded_model = torch.load('entire_model.pth')
# 此时不需要调用load_state_dict,因为模型已经被完整地恢复了
这种方式可能在不同的环境中复现性较差,因为它依赖于模型类的可用性和环境中的其他因素。因此,推荐的做法仍然是分开保存模型结构(通常作为代码的一部分)和模型参数(.pth
文件)。当需要在不同环境中加载模型时,首先按照相同的结构重新构建模型,然后再加载参数。
pytorch 模型训练和推理步骤
使用PyTorch构建、训练和推理模型的详细步骤及代码示例,以一个简单的多层感知机(MLP)为例,数据集为随机生成的数据。
1. 导入所需库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
2. 定义自定义数据集类 (假设我们有输入和标签)
class RandomDataset(Dataset):
def __init__(self, n_samples=1000, input_dim=10, output_dim=1):
self.x = torch.randn(n_samples, input_dim)
self.y = torch.randn(n_samples, output_dim)
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
3. 定义模型类
class MLP(nn.Module):
def __init__(self, input_dim=10, hidden_dim=50, output_dim=1):
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
4. 构建模型并设置优化器
model = MLP()
criterion = nn.MSELoss() # 假设是回归问题,使用均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01) # 使用SGD优化器
# 或者也可以选择其他优化器,例如Adam
# optimizer = optim.Adam(model.parameters(), lr=0.001)
5. 数据加载器
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
# 训练前通常还会对数据进行标准化处理等预处理工作,此处为了简单起见省略
6. 模型训练
num_epochs = 100
for epoch in range(num_epochs):
for inputs, targets in dataloader:
# 前向传播
outputs = model(inputs)
# 计算损失
loss = criterion(outputs, targets)
# 反向传播与优化
optimizer.zero_grad() # 清空梯度缓存
loss.backward() # 计算梯度
optimizer.step() # 更新权重
# 每个epoch结束后打印损失值
print(f"Epoch {epoch + 1}: Loss: {loss.item():.4f}")
7. 模型推理(预测)
# 在训练完成后,可以使用模型进行推理
test_data = torch.randn(10, 10) # 假设这是新的测试样本
with torch.no_grad(): # 推理时不需要计算梯度
predictions = model(test_data)
print("Predictions:", predictions)