PyTorch-网络模型的保存和读取

1. 模型的保存

方法一:保存模型的结构+模型的参数

陷阱:需要让文件访问到你自己的模型定义方式,可以用import的方式引入先前的模型定义。

model_save.py

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=None)
# 保存方式一
torch.save(vgg16, 'vgg16_method1.pth')

方法二:保存模型的参数(官方推荐,文件小一些)

model_save.py

import torch
import torchvision

vgg16 = torchvision.models.vgg16(weights=None)
# 保存方式二 保存网络模型的参数
torch.save(vgg16.state_dict(), 'vgg16_method2.pth')

2. 模型的加载

model_load.py(对应方法一的)

import torch

# 加载模型
model = torch.load('vgg16_method1.pth')
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Process finished with exit code 0

model_load.py(对应方法二的)

model2 = torch.load('vgg16_method2.pth')
print(model2)

OrderedDict([('features.0.weight', tensor([[[[ 0.0588, -0.0743, -0.1424],
          [-0.0034,  0.0577,  0.0819],
          [-0.0233, -0.0427,  0.1821]],

         [[ 0.0583, -0.0244,  0.0121],
          [ 0.0243, -0.0532,  0.0252],
          [-0.0372,  0.0098,  0.0754]],

         [[ 0.0480,  0.0094,  0.0544],
          [-0.0291, -0.0081,  0.0834],
          [-0.0282,  0.0537, -0.0363]]],

......

若要恢复网络模型: 

import torch
import torchvision

# 加载模型
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load('vgg16_method2.pth'))
print(vgg16)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)

......

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
要加载一个预训练模型,你需要知道模型的结构和权重参数。通常情况下,这些信息会以不同的文件形式保存。 要加载模型结构,你需要用到PyTorch中的`torch.nn.Module`。如果你有模型结构的代码,你可以直接使用该代码来创建模型对象。如果你已经将模型结构保存到文件中,则可以使用`torch.load()`函数加载该文件,然后使用其中存储的模型对象生成模型。 要加载模型权重,你需要使用`torch.load()`函数加载包含权重的文件。当你加载权重时,你需要确保模型结构与权重是兼容的,否则加载权重将会失败。一般情况下,你需要先创建一个模型对象,然后再加载权重。你可以使用`model.load_state_dict()`函数将权重加载到模型中,其中`model`是你要加载权重的模型对象。 下面是一个简单的示例代码,演示了如何加载模型结构和权重: ```python import torch import torchvision.models as models # 加载预训练模型的结构 model = models.resnet18(pretrained=False) # 加载预训练模型的权重 state_dict = torch.load('resnet18.pth') model.load_state_dict(state_dict) ``` 在这个例子中,我们使用了PyTorch中内置的ResNet-18模型作为示例,首先创建了一个模型对象,然后从文件中加载了预训练模型的权重。请注意,我们在加载权重之前将`pretrained`参数设置为`False`,以确保不会自动下载预训练模型
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值