pytorch中的模型保存与加载

本文介绍了PyTorch中模型的保存和加载过程,包括torch.save()和torch.load()的使用方法,以及如何进行断点续训。详细阐述了整模型存储和模型参数存储的区别,并通过实例演示了加载state_dict时的注意事项,强调了模型结构一致性的重要性。
摘要由CSDN通过智能技术生成


该篇笔记整理自余庭嵩的讲解。

存储和装载(序列化与反序列化)

基本概念

模型在训练的时候,各个参数都是存储在内存中的,但是内存不具备长久存储数据的功能,所以就需要将内存中的数据搬到硬盘上进行存储,以备后续用途。而模型的保存与加载,就叫做序列化与反序列化。
序列化:将内存中的模型保存到硬盘中,以二进制序列的形式进行存储。
反序列化:将硬盘中的二进制序列,反序列化到内存中形成对象,由此就能使用模型了。

pytorch中的序列化与反序列化

torch.save()

主要参数
obj:可以是模型,张量,parameter,dict等等,只要python中的一切皆是对象
f:输出路径

使用方法

1、保存整个Module

torch.save(net, path)

2、保存模型参数

state_dict = net.state_dict()
torch.save(state_dict, path)

这两个方法有什么区别呢?具体需要看两个被存储的数据结构。
法1:Module中有8个有序字典去管理,同时还有其他定义的东西,整个模型进行保存会占内存,属于懒方法。
法2:由于模型是通过学习得到的最新的可学习参数,所以方法2就把这些可学习参数保存下来,等下次使用模型的时候直接调取这些可学习参数到模型当中。module下的方法state_dict()会自动将所有可学习的参数调取出来,以字典的形式返回。

torch.load()

主要参数
f:文件路径,对应save中的f
map_location:定义存储在什么位置,GPU上的模型保存如果用普通的方法是不能load进来的

动手实验

首先定义一个网络LeNet,对其参数用自定义的initialize函数改动一次之后进行存储,分别以整模型存储和state_dict形式存储

import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2</
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值