任务
任务简介
了解序列化与反序列化
详细说明
学习pytorch中的模型保存与加载,也常称为序列化与反序列化。
了解序列化与反序列化的概念,进而对模型的保存与加载有深刻的认识,同时介绍pytorch中模型保存与加载的方法函数。
知识点
一、序列化与反序列化
序列化:变量从内存中
变成可存储或传输的过程称之为序列化
反序列化:把变量内容从序列化的对象重新读到内存
里称之为反序列化
1. Pytorch
的序列化–torch.save
- 主要参数:
• obj:对象
• f:输出路径
2. Pytorch
的反序列化–torch.load
- 主要参数
• f:文件路径
• map_location:指定存放位置, cpu or gpu
二、模型保存与加载的两种方式
保存模型
模型的保存方式有两种,一种是把整个模型的所有东西都保存下来,这种方法占用的磁盘空间较大,但是保存的信息较为全面;另一种是只保存模型的关键参数,其他的不保存,这种方法占用的磁盘空间较小,但是只保存了模型的关键参数信息,
官方推荐第二种保存方式
法1: 保存整个Module –torch.save(net, path)
net=LeNet2(classes=2019)
path_model = "./model.pkl"
# 保存整个模型
torch.save(net, path_model)
法2: 保存模型参数–net.state_dict()
与torch.save(state_dict , path)
net=LeNet2(classes=2019)
path_state_dict = "./model_state_dict.pkl"
# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
实例演示:
# -*- coding: utf-8 -*-
"""
# @brief : 模型的保存
"""
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools2 import set_seed
class LeNet2(nn.Module):
def __init__(self, classes):
super(LeNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.classifier = nn.Sequential(
nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
def initialize(self):
for p in self.parameters():
p.data.fill_(20191104)
net = LeNet2(classes=2019)
# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"
# 保存整个模型
torch.save(net, path_model)
# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
加载模型
加载整个模型–torch.load(path_model)
path_model = "./model.pkl"
net_load = torch.load(path_model)
print(net_load)
加载模型参数–torch.load(path_state_dict)
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
print(state_dict_load.keys())
net_new = LeNet2(classes=2019)
net_new.load_state_dict(state_dict_load)
实例演示:
# -*- coding: utf-8 -*-
"""
# @brief : 模型的加载
"""
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools2 import set_seed
class LeNet2(nn.Module):
def __init__(self, classes):
super(LeNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.classifier = nn.Sequential(
nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes)
)
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)