pytorch保存和加载模型

核心

在保存和加载模型方面主要有三个核心的方法:

  • torch.save:将对象序列化保存到磁盘中,该方法原理是基于python中的pickle来序列化,各种Models,tensors,dictionaries 都可以使用该方法保存。保存的模型文件名可以是.pth, .pt, .pkl

    save(
        obj: object,
        f: FILE_LIKE,
        pickle_module: Any = pickle,
        pickle_protocol: int = DEFAULT_PROTOCOL,
        _use_new_zipfile_serialization: bool = True
    ) -> None:
    
    • obj:保存的对象
    • f:一个类似文件的对象(必须实现写入和刷新)或字符串或操作系统。包含文件名的类似路径对象
    • pickle_module:用于挑选元数据和对象的模块
    • pickle_protocol:可以指定以覆盖默认协议
  • torch.load:采用 pickle 将反序列化的对象从存储中加载进内存。

    def load(
        f: FILE_LIKE,
        map_location: MAP_LOCATION = None,
        pickle_module: Any = None,
        *,
        weights_only: bool = False,
        **pickle_load_args: Any
    ) -> Any:
    
    • f:类文件对象 (返回文件描述符)或一个保存文件名的字符串
    • map_location:一个函数或字典规定如何映射存储设备,torch.device对象
    • pickle_module:用于 unpickling 元数据和对象的模块 (必须匹配序列化文件时的 pickle_module )
  • torch.nn.Module.load_state_dict:采用一个反序列的state_dict()方法将模型的参数加载到模型结构上。

    load_state_dict(state_dict, strict=True)
    
    • state_dict:保存 parameters 和 persistent buffers 的字典
    • strict:可选,bool型。state_dict 中的 key 是否和 model.state_dict() 返回的 key 一致。

序列化 (Serialization)是将对象的状态信息转换为可以存储或传输的形式的过程。. 在序列化期间,对象将其当前状态写入到临时或持久性存储区。以后,可以通过从存储区中读取或反序列化对象的状态,重新创建该对象。

状态字典state_dict

首先需要了解一下state_dict(),这个翻译成中文就是状态字典,在神经网络中可以认为是模型上训练出来的模型参数,也就是权重和偏置值。在Pytorch中,定义网络模型是通过继承torch.nn.Module来实现的。其网络模型中包含可学习的参数(weights, bias, 和一些登记的缓存如batchnorm’s running_mean 等)。模型内部的可学习参数可通过两种方式进行调用:

  1. 通过model.parameters()这个生成器来访问所有参数。
  2. 通过model.state_dict()来为每一层和它的参数建立一个映射关系并存储在字典中,其键值由每个网络层和其对应的参数张量构成。

除模型外,优化器对象(torch.optim)同样也有一个状态字典,包含的优化器状态信息以及使用的超参数。由于状态字典属于Python 字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都比较简单。

import torch.optim as optim

from torchvision.models import resnet18

# 定义模型,并使用预训练权重
model = resnet18(pretrained=True)

# 定义优化器,随机梯度优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 模型的 state_dict
print(type(model.state_dict()))  # <class 'collections.OrderedDict'>
for key, value in model.state_dict().items():
    print(key, "\t", value.size())

# 模型的model.parameters()
print(model.parameters())  # 生成器对象 <generator object Module.parameters at 0x00000276B61FC820>
# 两种方式查看,一种转成list: list(model.parameters())
# 另一种使用for循环
for para in model.parameters():
    print(para.size())

# 优化器的 state_dict
for key, value in optimizer.state_dict().items():
    print(key, value)

上述使用了ResNet18,并使用了预训练权重,然后

  1. 通过.state_dict()打印模型每个网络层名字和参数
  2. 通过.parameters()打印模型的每个网络层的参数
  3. 通过.state_dict()打印优化器每个网络层名字和参数

加载/保存 状态字典(state_dict)

保存一个模型的时候,只需要保存训练模型的可学习参数即可。通过torch.save() 来保存模型的状态字典。

加载模型时,通过torch.load()来加载模型的状态字典,且通过.load_state_dict()将状态字典加载到模型上。

import torch
from torchvision.models import resnet18

# 定义模型,并使用预训练权重
model = resnet18(pretrained=True)

# 保存模型的state_dict
PATH = "./model.pth"  # 可以是pth,pt,pkl后缀文件名
torch.save(model.state_dict(), PATH)

# =======加载模型=======
# 1. 定义模型
model = resnet18()
# 2. 加载状态字典进内存
state_dict = torch.load(PATH)

# 3. 将状态字典中的可学习参数加载到模型上使用
model.load_state_dict(state_dict)

加载/保存整个模型

使用该方法,就不只是保存模型的状态字典,而是保存整个模型对象(包括模型结构和状态字典等)。

这种保存模型的做法是采用 Python 的 pickle 模块来保存整个模型,该方法的缺点是 pickle 不保存模型类别,而是保存一个包含该类的文件的路径。所以在加载的时候可能会报错Can't get attribute 'TheModelClass' on <module '__main__' from "...."

import torch
from torchvision.models import resnet18

# 定义模型,并使用预训练权重
model = resnet18(pretrained=True)

# 保存模型的全部
PATH = "./model.pth"  # 可以是pth,pt,pkl后缀文件名
torch.save(model, PATH)
import torch

# 需要注意定义一下这个类,否则会报错
from torchvision.models import resnet18

# 加载就行
PATH = "./model.pth" # 路径
model = torch.load(PATH)

加载和保存一个通用的检查点(Checkpoint)

将训练过程中的中断信息进行保存,使其能够继续训练或者用于预测。其中保存的信息包括但不限于:

  • epoch
  • 模型的state_dict
  • 优化器的state_dict
  • loss函数

保存如下:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

保存代码参考如下:

import torch
from torchvision.models import resnet18
import torch.optim as optim

# 定义模型,并使用预训练权重
model = resnet18(pretrained=True)

# 定义优化器,随机梯度优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 定义损失函数,交叉熵损失函数
loss = torch.nn.CrossEntropyLoss()

epoch = 50

# 保存Checkpoint
PATH = "./model.pth"
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, PATH)

加载代码参考如下:

import torch
from torchvision.models import resnet18
import torch.optim as optim

# 定义模型
model = resnet18()

# 定义优化器,随机梯度优化器
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# 加载checkpoint
PATH = "./model.pth"
checkpoint = torch.load(PATH)

# 模型的状态字典加载到模型上
model.load_state_dict(checkpoint['model_state_dict'])
# 将优化器的状态字典加载到优化器上
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epoch = checkpoint['epoch']
loss = checkpoint['loss']

使用旧模型的参数来预热新模型(Warmstaring Model)

主要用于fine-tune微调,加快模型的训练速度和提高精度,通常是加载预训练模型的部分网络参数作为模型的初始化参数以达到该效果。该方法使用与之前的类似,不过在将状态字典加载到模型上时需要设置strict=False,表示当模型中的key与状态字典中的key不匹配时暂时跳过不管,但当模型中的key与状态字典中的key匹配时,pytorch就会尝试帮我们加载参数。

因此,如果要使用字典来载入权重的参数必须要保证模型参数的名称一模一样,也就是模型的key与状态字典中的key匹配

这里提供两种方式来进行修改参数:

加载状态字典,并将状态字典全部加载到模型上,通过修改加载后的模型结构,以此来修改修改结构的参数。

from torchvision.models import resnet18
import torch

model = resnet18(num_classes=1000)  # 定义模型
PATH = "./model.pth"  # 定义状态字典的文件路径
device = torch.device("cpu")  # 定义设备

# 加载PATH文件中的状态字典,并放入device中
weights_dict = torch.load(PATH, map_location=device)
model.load_state_dict(weights_dict)  # 整体载入全部参数

# 修改最后一层
in_channel = model.fc.in_features  # 原始模型最后一层为model1.fc = nn.Liean(in_features,1000)
model.fc = torch.nn.Linea(in_channel, 5)  # 将原始模型model1.fc进行修改

通过修改,删除,增加状态字典中的key和value进而实现模型加载不同的状态字典。

from torchvision.models import resnet18
import torch

model = resnet18(num_classes=5)  # 定义模型
PATH = "./model.pth"  # 定义状态字典的文件路径
device = torch.device("cpu")  # 定义设备

# 加载PATH文件中的状态字典,并放入device中
weights_dict = torch.load(PATH, map_location=device)

# ========== 方案一 ==========
# 判断模型中的每层tensor是否一样,如果一样保留作为状态字典,否则舍弃
load_weights_dict = {k: v for k, v in weights_dict.items() if model.state_dict()[k].numel() == v.numel()}

# 将自己筛选过后的状态字典加载到模型上
print(model.load_state_dict(load_weights_dict, strict=False))

# ========== 方案二 ==========
# 删除自定义的字典
del weights_dict["fc.weight"]
del weights_dict["fc.bias"]

missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
print(f"missing_keys={missing_keys}, unexpected_keys={unexpected_keys}")
# missing_keyes表示模型中没有被状态字典初始化的key集
# unexpected_keys表示状态字典中的权重不在模型中的state_dict()内
# 注意各种state_dict()都是有顺序的字典,都是一一对应的

参考:Serialization semantics — PyTorch master documentation

Saving and Loading Models — PyTorch Tutorials 1.12.1+cu102 documentation

  • 5
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值