Pytorch 网络模型的保存与读取

Pytorch 网络模型的保存与读取

方法一:

模型保存:该方法将保存模型的网络结构+参数

import torch
import torchvision
#创建模型
vgg16=torchvision.vgg16( pretrained=False)
#模型保存
torch.save(vgg16, "vgg16_method1.pth")

模型读取:

import torch
model= torch.load("vgg16_method1.pth")
print(model)

方法二:官方推荐使用该方法

模型保存:该方法仅保存模型参数

import torch
import torchvision
#创建模型
vgg16=torchvision.vgg16( pretrained=False)
#模型保存
torch.save(vgg16.state_dict(), "vgg_method2.pth")

模型读取:

vgg16=torchvision.model.vgg16(pretrained=False)
vgg16.load_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

感谢:
https://www.bilibili.com/video/BV1hE411t7RN/?p=26&spm_id_from=pageDriver&vd_source=5b6e0605c1ed0f1db9c92503dd5994e0

要加载一个预训练模型,你需要知道模型的结构和权重参数。通常情况下,这些信息会以不同的文件形式保存。 要加载模型结构,你需要用到PyTorch中的`torch.nn.Module`。如果你有模型结构的代码,你可以直接使用该代码来创建模型对象。如果你已经将模型结构保存到文件中,则可以使用`torch.load()`函数加载该文件,然后使用其中存储的模型对象生成模型。 要加载模型权重,你需要使用`torch.load()`函数加载包含权重的文件。当你加载权重时,你需要确保模型结构与权重是兼容的,否则加载权重将会失败。一般情况下,你需要先创建一个模型对象,然后再加载权重。你可以使用`model.load_state_dict()`函数将权重加载到模型中,其中`model`是你要加载权重的模型对象。 下面是一个简单的示例代码,演示了如何加载模型结构和权重: ```python import torch import torchvision.models as models # 加载预训练模型的结构 model = models.resnet18(pretrained=False) # 加载预训练模型的权重 state_dict = torch.load('resnet18.pth') model.load_state_dict(state_dict) ``` 在这个例子中,我们使用了PyTorch中内置的ResNet-18模型作为示例,首先创建了一个模型对象,然后从文件中加载了预训练模型的权重。请注意,我们在加载权重之前将`pretrained`参数设置为`False`,以确保不会自动下载预训练模型
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值