当使用PyTorch加载保存的模型时,可以将加载模型的代码单独写成一个函数或模块,以便在需要的地方进行调用。下面是一个例子来说明如何单独加载模型:
import torch
import torch.nn as nn
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.fc = torch.nn.Linear(16 * 30 * 30, 10)
def forward(self, x):
x = self.conv(x)
x = torch.relu(x)
x = x.view(-1, 16 * 30 * 30)
x = self.fc(x)
return x # 返回输出张量
model = Model()
#
# model = TheModelClass()
# 动态图 模型保存权重和模型结构
torch.save(model, "dongtai.pt")
# 动态图 模型保存权重
torch.save(model.state_dict(), "dongtai_state_dict.pth")
# 模型保存,方法三静态图
x = torch.rand(1, 3, 30, 30)
trace_model = torch.jit.trace(model, x)
torch.jit.save(trace_model, "jingtai.pt")
在上述例子中,我们定义了模型类 TheModelClass
。该函数接受一个模型路径 model_path
作为参数,并返回加载后的模型对象。
- 动态图加载模型结构和权重:
import torch
from test import Model# 不导入或同级下找不到会有问题
#保存包含类的文件的路径,该文件在加载时使用
model = torch.load("dongtai.pt")
model.eval()
"dongtai.pt"
中。加载模型时,我们使用 torch.load
方法加载模型文件,得到完整的模型结构和权重。
- 动态图加载权重:
import torch
import torch.nn as nn
from test import Model# 不导入或同级下找不到会有问题
model = Model()
# 加载模型权重
weights = torch.load("dongtai_state_dict.pth")
model.load_state_dict(weights)
- 静态图加载权重:
import torch
import torch.nn as nn
# 加载静态图模型
model_ji = torch.jit.load("jingtai.pt")
在上述例子中,我们同样定义了一个模型类 TheModelClass
并实例化了一个模型对象 model
。我们创建了一个随机输入张量 x
,并使用 torch.jit.trace
方法运行模型并记录运行路径。然后,我们使用 torch.jit.save
方法将运行路径保存到文件中,这种方法会自动记录模型中节点的数据流动路径。加载模型时,我们可以直接使用 torch.jit.load
方法加载保存的静态图模型文件,因为该模型已经记录了节点的权重和数据流动路径,所以只需将数据输入模型,即可得到输出结果。
这些例子说明了如何使用不同的方法来保存和加载PyTorch模型。根据具体的需求,选择适合的方法来保存和加载模型。