PyTorch加载保存的模型时,加载模型的三种方式

当使用PyTorch加载保存的模型时,可以将加载模型的代码单独写成一个函数或模块,以便在需要的地方进行调用。下面是一个例子来说明如何单独加载模型:

import torch
import torch.nn as nn

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.fc = torch.nn.Linear(16 * 30 * 30, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.relu(x)
        x = x.view(-1, 16 * 30 * 30)
        x = self.fc(x)
        return x  # 返回输出张量

model = Model()


#
# model = TheModelClass()
#  动态图  模型保存权重和模型结构
torch.save(model, "dongtai.pt")
# 动态图  模型保存权重

torch.save(model.state_dict(), "dongtai_state_dict.pth")

# 模型保存,方法三静态图
x = torch.rand(1, 3, 30, 30)
trace_model = torch.jit.trace(model, x)
torch.jit.save(trace_model, "jingtai.pt")

在上述例子中,我们定义了模型类 TheModelClass 。该函数接受一个模型路径 model_path 作为参数,并返回加载后的模型对象。

  1. 动态图加载模型结构和权重:
import torch
from test import Model# 不导入或同级下找不到会有问题
#保存包含类的文件的路径,该文件在加载时使用
model = torch.load("dongtai.pt")
model.eval()

"dongtai.pt" 中。加载模型时,我们使用 torch.load 方法加载模型文件,得到完整的模型结构和权重。

  1. 动态图加载权重:
import torch
import torch.nn as nn
from test import Model# 不导入或同级下找不到会有问题

model = Model()

# 加载模型权重
weights = torch.load("dongtai_state_dict.pth")
model.load_state_dict(weights)
  1. 静态图加载权重:
import torch
import torch.nn as nn
# 加载静态图模型
model_ji = torch.jit.load("jingtai.pt")

在上述例子中,我们同样定义了一个模型类 TheModelClass 并实例化了一个模型对象 model。我们创建了一个随机输入张量 x,并使用 torch.jit.trace 方法运行模型并记录运行路径。然后,我们使用 torch.jit.save 方法将运行路径保存到文件中,这种方法会自动记录模型中节点的数据流动路径。加载模型时,我们可以直接使用 torch.jit.load 方法加载保存的静态图模型文件,因为该模型已经记录了节点的权重和数据流动路径,所以只需将数据输入模型,即可得到输出结果。

这些例子说明了如何使用不同的方法来保存和加载PyTorch模型。根据具体的需求,选择适合的方法来保存和加载模型。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值