【深度之眼】【Pytorch打卡第15天】:模型保存与加载

任务

任务简介

了解序列化与反序列化

详细说明

学习pytorch中的模型保存与加载,也常称为序列化与反序列化。
了解序列化与反序列化的概念,进而对模型的保存与加载有深刻的认识,同时介绍pytorch中模型保存与加载的方法函数。


知识点

一、序列化与反序列化

序列化:变量从内存中变成可存储或传输的过程称之为序列化
反序列化:把变量内容从序列化的对象重新读到内存里称之为反序列化

1. Pytorch的序列化–torch.save

  • 主要参数:
    • obj:对象
    • f:输出路径

2. Pytorch的反序列化–torch.load

  • 主要参数
    • f:文件路径
    • map_location:指定存放位置, cpu or gpu

二、模型保存与加载的两种方式

保存模型

模型的保存方式有两种,一种是把整个模型的所有东西都保存下来,这种方法占用的磁盘空间较大,但是保存的信息较为全面;另一种是只保存模型的关键参数,其他的不保存,这种方法占用的磁盘空间较小,但是只保存了模型的关键参数信息,官方推荐第二种保存方式

法1: 保存整个Module –torch.save(net, path)
net=LeNet2(classes=2019)

path_model = "./model.pkl"

# 保存整个模型
torch.save(net, path_model)
法2: 保存模型参数–net.state_dict()torch.save(state_dict , path)
net=LeNet2(classes=2019)

path_state_dict = "./model_state_dict.pkl"

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

实例演示:

# -*- coding: utf-8 -*-
"""
# @brief      : 模型的保存
"""
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools2 import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


net = LeNet2(classes=2019)

# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

加载模型

加载整个模型–torch.load(path_model)
path_model = "./model.pkl"
net_load = torch.load(path_model)

print(net_load)
加载模型参数–torch.load(path_state_dict)
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)

print(state_dict_load.keys())

net_new = LeNet2(classes=2019)
net_new.load_state_dict(state_dict_load)

实例演示:

# -*- coding: utf-8 -*-
"""
# @brief      : 模型的加载
"""
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools2 import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值