Saving and Loading Network Models

Method 1:

1.1 How to save a network model

First, create a .py file named model_save.py:

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)  # build VGG16 without downloading pretrained weights
torch.save(vgg16, "vgg16_model1_pth")                # method 1: save the whole model (structure + parameters)

After the script finishes running, a file named vgg16_model1_pth appears in the project directory (the file list on the left in PyCharm).
Saving this way stores not only the network structure but also all of its parameters.
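If you want to confirm that the file was actually written, a minimal check (a sketch, assuming the save script above has already been run in the current working directory) is to look at its size on disk:

import os

# Hypothetical check: the filename matches the torch.save call above.
size_mb = os.path.getsize("vgg16_model1_pth") / (1024 * 1024)
print(f"vgg16_model1_pth: {size_mb:.1f} MB")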

1.2 How to load a network model

Create another .py file, model_load.py:

import torch

model = torch.load("vgg16_model1_pth")  # method 1: load the whole model back
print(model)

Output:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
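Because method 1 restores the complete model object, it can be used right away. As a quick sanity check (a sketch, assuming the standard VGG16 input size of 3×224×224), run a dummy batch through the loaded model and inspect the output shape:

import torch

model = torch.load("vgg16_model1_pth")
model.eval()                          # switch Dropout layers to inference mode
dummy = torch.randn(1, 3, 224, 224)   # one fake RGB image at the size VGG16 expects
with torch.no_grad():                 # no gradients needed for a sanity check
    out = model(dummy)
print(out.shape)                      # expected: torch.Size([1, 1000])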

Method 2:

2.1 How to save a network model

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# torch.save(vgg16, "vgg16_model1_pth")             # method 1 (shown above, for comparison)
torch.save(vgg16.state_dict(), "vgg16_model2_pth")  # method 2: save only the parameters (state_dict)

This also creates a vgg16_model2_pth file in the project directory.
Only the model's parameters are saved, so the file takes less disk space; this is the officially recommended approach.
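To see what state_dict() actually contains before it is written to disk, you can iterate over it; it is an ordered mapping from each parameter name to its tensor (a sketch):

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# Print the first few entries: parameter name -> tensor shape.
for i, (name, tensor) in enumerate(vgg16.state_dict().items()):
    print(name, tuple(tensor.shape))
    if i == 4:
        break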

2.2 How to load a network model

Loading looks the same as in method 1, but what comes back is dictionary-type data:

import torch

# Method 2 -> load what was saved with method 2
model = torch.load("vgg16_model2_pth")  # what we get back is dictionary-type data
print(model)

F:\Anaconda3\envs\pytorch\python.exe D:/Python/learn_torch/model_load.py
OrderedDict([('features.0.weight', tensor([[[[ 3.9726e-02, -4.0263e-02,  5.2152e-02],
          [ 3.5984e-02, -4.6239e-02, -2.4924e-02],
          [-9.6867e-03,  1.2961e-02, -4.5731e-02]],

         [[ 1.9925e-03,  3.6464e-02,  5.6411e-02],
          [-9.0956e-02, -3.6801e-02, -7.3917e-02],
          [ 3.6363e-02, -4.5585e-02, -8.2003e-03]],

         [[-1.1151e-01, -2.4694e-02, -3.4446e-02],
          [-5.4018e-02,  7.9030e-02,  1.1468e-01],
          [ 6.1839e-02, -8.7451e-02,  2.8596e-03]]],


        [[[-6.4775e-02,  5.2936e-03, -1.8106e-02],
          [-4.0254e-02, -8.5685e-02, -7.8011e-02],
          [ 1.1739e-02, -7.9629e-02,  6.6174e-02]],

         [[-1.1657e-01,  3.5422e-02,  6.2663e-02],
          [ 3.0534e-02,  6.9120e-03,  3.3340e-03],
          [-1.5356e-01,  7.2058e-02,  4.7606e-02]],

         [[-1.2942e-01, -3.5475e-02,  9.7374e-02],
          [-1.3898e-02, -2.5312e-02,  6.3060e-02],
          [ 5.4231e-04,  1.4181e-02,  8.3530e-02]]],


        [[[-1.5726e-03,  6.0129e-02, -2.5256e-02],
          [-8.2932e-02,  9.2577e-02,  1.8457e-02],
          [-5.7204e-02, -5.2296e-02,  8.6386e-02]],

         [[-3.1392e-02,  1.2295e-01, -6.2096e-03],
          [-1.6034e-02,  3.0497e-03,  5.9402e-02],
          [-7.5480e-02, -6.9659e-02, -1.2263e-02]],

         [[ 6.5706e-05, -4.6442e-02,  6.1466e-02],
          [ 3.6150e-02,  3.6947e-02, -9.4802e-02],
          [ 7.0997e-02,  1.2181e-02,  3.3660e-03]]],
          ...

As the output above shows, what we loaded is a dictionary, and the parameter values are printed along with it. If you want to see the actual network structure, you need to do the following:


import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)     # rebuild the network structure first
vgg16.load_state_dict(torch.load("vgg16_model2_pth"))  # then load the saved parameters into it
print(vgg16)  # prints the full model structure, the same as method 1's output
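
Putting method 2 together, a typical end-to-end round trip looks like this (a sketch, assuming the vgg16_model2_pth file saved earlier): rebuild the architecture, load the parameters into it, then switch to eval mode for inference:

import torch
import torchvision

# Rebuild the architecture, then fill it with the saved parameters.
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_model2_pth"))

vgg16.eval()                                  # inference mode (affects the Dropout layers)
with torch.no_grad():
    out = vgg16(torch.randn(1, 3, 224, 224))
print(out.argmax(dim=1))                      # index of the highest-scoring class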