文章目录
该篇笔记整理自余庭嵩的讲解。
存储和装载(序列化与反序列化)
基本概念
模型在训练的时候,各个参数都是存储在内存中的,但是内存不具备长久存储数据的功能,所以就需要将内存中的数据搬到硬盘上进行存储,以备后续用途。而模型的保存与加载,就叫做序列化与反序列化。
序列化:将内存中的模型保存到硬盘中,以二进制序列的形式进行存储。
反序列化:将硬盘中的二进制序列,反序列化到内存中形成对象,由此就能使用模型了。
pytorch中的序列化与反序列化
torch.save()
主要参数
obj:可以是模型,张量,parameter,dict等等,只要python中的一切皆是对象
f:输出路径
使用方法
1、保存整个Module
torch.save(net, path)
2、保存模型参数
state_dict = net.state_dict()
torch.save(state_dict, path)
这两个方法有什么区别呢?具体需要看两个被存储的数据结构。
法1:Module中有8个有序字典去管理,同时还有其他定义的东西,整个模型进行保存会占内存,属于懒方法。
法2:由于模型是通过学习得到的最新的可学习参数,所以方法2就把这些可学习参数保存下来,等下次使用模型的时候直接调取这些可学习参数到模型当中。module下的方法state_dict()会自动将所有可学习的参数调取出来,以字典的形式返回。
torch.load()
主要参数
f:文件路径,对应save中的f
map_location:定义存储在什么位置,GPU上的模型保存如果用普通的方法是不能load进来的
动手实验
首先定义一个网络LeNet,对其参数用自定义的initialize函数改动一次之后进行存储,分别以整模型存储和state_dict形式存储。
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
class LeNet2(nn.Module):
def __init__(self, classes):
super(LeNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2</