PyTorch学习笔记(十一) ---- PyTorch保存和加载模型

转载请注明作者和出处: http://blog.csdn.net/john_bh/


保存和加载模型的三个核心功能:

  • torch.save:将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化。使用此函数可以保存如模型、tensor、字典等各种对象。
  • torch.load:使用pickle的unpickling功能将pickle对象文件反序列化到内存。此功能还可以有助于设备加载数据。
  • torch.nn.Module.load_state_dict:使用反序列化函数 state_dict 来加载模型的参数字典。

1.state_dict

在PyTorch中,torch.nn.Module模型的可学习参数(即权重和偏差)包含在模型的参数中,(使用model.parameters()可以进行访问)。 state_dict是Python字典对象,它将每一层映射到其参数张量。注意,只有具有可学习参数的层(如卷积层,线性层等)的模型 才具有state_dict这一项目标优化torch.optim也有state_dict属性,它包含有关优化器的状态信息,以及使用的超参数。

因为state_dict的对象是Python字典,所以它们可以很容易的保存、更新、修改和恢复,为PyTorch模型和优化器添加了大量模块。

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.pool= nn.MaxPool2d(2,2)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
        
    def forward(self,x):
        x=self.pool(F.relu(self.conv1(x)))
        x=self.pool(F.relu(self.conv2(x)))
        x=x.view(-1,16*5*5)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        
        return x
    
model=TheModelClass() # 初始化模型

optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9) # 初始化优化器

# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor,'\t',model.state_dict()[param_tensor].size())
    
# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name,'\t',optimizer.state_dict()[var_name])

输出:

Model's state_dict:
conv1.weight  	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2824824504464, 2824824505328, 2824824507704, 2824824507416, 2824824506624, 2824824505688, 2824824506264, 2824824508280, 2824824507056, 2824824507128]}]

2.保存和加载推理模型

2.1 保存/加载state_dict(推荐使用)

当保存好模型用来推断的时候,只需要保存模型学习到的参数,使用torch.save()函数来保存模型state_dict,它会给模型恢复提供 最大的灵活性。

  • 保存模型:
    torch.save(model.state_dict(), PATH)
    
  • 加载模型:
    model = TheModelClass(*args, **kwargs) #这里需要重新模型结构
    model.load_state_dict(torch.load(PATH))
    model.eval()
    

在 PyTorch 中最常见的模型保存使‘.pt’或者是‘.pth’作为模型文件扩展名。

注意:在运行推理之前,务必调用model.eval()去设置 dropout 和 batch normalization 层为评估模式。如果不这么做,可能导致 模型推断结果不一致。

load_state_dict()函数只接受字典对象,而不是保存对象的路径。这就意味着在你传给load_state_dict()函数之前,你必须反序列化 你保存的state_dict。例如,你无法通过 model.load_state_dict(PATH)来加载模型。

2.2 保存/加载完整模型

  • 保存模型:
    torch.save(model, PATH) #保存整个model的状态
    
  • 加载模型:
    # 模型类必须在此之前被定义
    model=torch.load(PATH) #这里已经不需要重构模型结构了,直接load就可以
    model.eval()
    

以 Python `pickle 模块的方式来保存模型。这种方法的缺点是序列化数据受限于某种特殊的类而且需要确切的字典结构。这是因为pickle无法保存模型类本身。相反,它保存包含类的文件的路径,该文件在加载时使用。 因此,当在其他项目使用或者重构之后,你的代码可能会以各种方式中断。

注意:在运行推理之前,务必调用model.eval()去设置 dropout 和 batch normalization 层为评估模式。如果不这么做,可能导致 模型推断结果不一致。

2.3 保存/加载 .t7

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式

  • 保存模型:
    print('===> Saving models...')
    state = {
        'state': model.state_dict(),
        'epoch': epoch # 将epoch一并保存
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/autoencoder.t7')
    
    保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。
  • 加载模型:
    pytorch读取数据使用的方法和使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。
    print('===> Try resume from checkpoint')
    if os.path.isdir('checkpoint'):
        try:
            checkpoint = torch.load('./checkpoint/autoencoder.t7')
            model.load_state_dict(checkpoint['state'])        # 从字典中依次读取
            start_epoch = checkpoint['epoch']
            print('===> Load last checkpointdata')
        except FileNotFoundError:
            print('Can\'t found autoencoder.t7')
    else:
        start_epoch = 0
        print('===> Start from scratch')
    

2.4 保存模型文件类型 .pt , .pth, .pkl, .t7的区别

经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save()函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth.pkl。用相同的 torch.save()语句保存出来的模型文件没有什么不同。

在pytorch官方的文档/代码里,有用.pt的,也有用.pth的。一般惯例是使用.pth,但是官方文档里貌似.pt更多,而且官方也不是很在意固定用一种。

torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存整个神经网络的的结构信息和模型参数信息,save的对象是网络模型;二是只保存神经网络的训练模型参数,save的对象是net.state_dict()。如2.1和2.2中所述。
假设有一个训练好的模型名叫net,则

  • torch.save(net, ‘net.pth’) # 保存整个神经网络的结构和模型参数
  • torch.save(net, ‘net.pkl’) # 同上
  • torch.save(net.state_dict(), ‘net_params.pth’) # 只保存神经网络的模型参数
  • torch.save(net.state_dict(), ‘net_params.pkl’) # 同上
    如果你是使用torch.save方法来进行模型参数的保存,那保存文件的后缀其实没有任何影响,结果都是一样的,很多.pkl的文件也是用torch.save保存下来的,和.pth文件一模一样的。

不过,如果应用场景不是在这里,这两种格式的文件还是有区别的,
.pkl文件是python里面保存文件的一种格式,如果直接打开会显示一堆序列化的东西,其实就是以二进制形式存储的,如果去read这些文件,则需要用’rb’而不是’r’模式。
.pth文件则有不同的应用,Python在遍历已知的库文件目录过程中,如果见到一个.pth 文件,就会将文件中所记录的路径加入到 sys.path 设置中,于是 .pth 文件指明的库也就可以被 Python 运行环境找到了。

不管pkl文件还是pth文件,都是以二进制形式存储的,没有本质上的区别,你用pickle这个库去加载pkl文件或pth文件,效果都是一样的。

3. 保存和加载 Checkpoint 用于推理/继续训练

  • 保存模型:
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                ...
                }, PATH)
    
  • 加载模型:
    model = TheModelClass(*args, **kwargs)
    optimizer = TheOptimizerClass(*args, **kwargs)
    
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    model.eval()
    # - or -
    model.train()
    

当保存成 Checkpoint 的时候,可用于推理或者是继续训练,保存的不仅仅是模型的 state_dict 。保存优化器的 state_dict 也很重要, 因为它包含作为模型训练更新的缓冲区和参数。你也许想保存其他项目,比如最新记录的训练损失,外部的torch.nn.Embedding层等等。

要保存多个组件,请在字典中组织它们并使用torch.save()来序列化字典。PyTorch 中常见的保存checkpoint 是使用 .tar 文件扩展名。

要加载项目,首先需要初始化模型和优化器,然后使用torch.load()来加载本地字典。如果你想要恢复训练,请调用model.train()以确保这些层处于训练模式。

4. 在一个文件中保存多个模型

  • 保存模型:
    torch.save({
                'modelA_state_dict': modelA.state_dict(),
                'modelB_state_dict': modelB.state_dict(),
                'optimizerA_state_dict': optimizerA.state_dict(),
                'optimizerB_state_dict': optimizerB.state_dict(),
                ...
                }, PATH)
    
  • 加载模型:
    modelA = TheModelAClass(*args, **kwargs)
    modelB = TheModelBClass(*args, **kwargs)
    optimizerA = TheOptimizerAClass(*args, **kwargs)
    optimizerB = TheOptimizerBClass(*args, **kwargs)
    
    checkpoint = torch.load(PATH)
    modelA.load_state_dict(checkpoint['modelA_state_dict'])
    modelB.load_state_dict(checkpoint['modelB_state_dict'])
    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
    
    modelA.eval()
    modelB.eval()
    # - or -
    modelA.train()
    modelB.train()
    

当保存一个模型由多个torch.nn.Modules组成时,例如GAN(对抗生成网络)、sequence-to-sequence (序列到序列模型), 或者是多个模 型融合, 可以采用与保存常规检查点相同的方法。换句话说,保存每个模型的 state_dict 的字典和相对应的优化器。如前所述,可以通 过简单地将它们附加到字典的方式来保存任何其他项目,这样有助于恢复训练。

5.使用在不同模型参数下的热启动模式

  • 保存模型:
    torch.save(modelA.state_dict(), PATH)
    
  • 加载模型:
    modelB = TheModelBClass(*args, **kwargs)
    modelB.load_state_dict(torch.load(PATH), strict=False)
    

在迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见的情况。利用训练好的参数,有助于热启动训练过程,并希望帮助你的模型比从头开始训练能够更快地收敛。

无论是从缺少某些键的 state_dict 加载还是从键的数目多于加载模型的 state_dict , 都可以通过在load_state_dict()函数中将strict参数设置为 False 来忽略非匹配键的函数。

如果要将参数从一个层加载到另一个层,但是某些键不匹配,主要修改正在加载的 state_dict 中的参数键的名称以匹配要在加载到模型中的键即可。

6.通过设备保存/加载模型

6.1 保存到 CPU、加载到 CPU

  • 保存模型:
    torch.save(model.state_dict(), PATH)
    
  • 加载模型:
    device = torch.device('cpu')
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location=device))
    

当从CPU上加载模型在CPU上训练时, 将torch.device(‘cpu’)传递给torch.load()函数中的map_location参数.在这种情况下,使用 map_location参数将张量下的存储器动态的重新映射到CPU设备。

6.2 保存到 GPU、加载到 GPU

  • 保存模型:
    torch.save(model.state_dict(), PATH)
    
  • 加载模型:
    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.to(device)
    # 确保在你提供给模型的任何输入张量上调用input = input.to(device)
    

当在GPU上训练并把模型保存在GPU,只需要使用model.to(torch.device(‘cuda’)),将初始化的 model 转换为 CUDA 优化模型。另外,请 务必在所有模型输入上使用.to(torch.device(‘cuda’))函数来为模型准备数据。请注意,调用my_tensor.to(device)会在GPU上返回my_tensor的副本。 因此,请记住手动覆盖张量:my_tensor= my_tensor.to(torch.device(‘cuda’))。

6.3 保存到 CPU,加载到 GPU

  • 保存模型:
    torch.save(model.state_dict(), PATH)
    
  • 加载模型:
    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
    model.to(device)
    # 确保在你提供给模型的任何输入张量上调用input = input.to(device)
    

在CPU上训练好并保存的模型加载到GPU时,将torch.load()函数中的map_location参数设置为cuda:device_id。这会将模型加载到 指定的GPU设备。接下来,请务必调用model.to(torch.device(‘cuda’))将模型的参数张量转换为 CUDA 张量。最后,确保在所有模型输入上使用 .to(torch.device(‘cuda’))函数来为CUDA优化模型。请注意,调用my_tensor.to(device)会在GPU上返回my_tensor的新副本。它不会覆盖my_tensor。 因此, 请手动覆盖张量my_tensor = my_tensor.to(torch.device(‘cuda’))。

6.4 保存 torch.nn.DataParallel 模型

  • 保存模型:
    torch.save(model.state_dict(), PATH)
    
  • 加载模型:
    # 加载任何你想要的设备
    

torch.nn.DataParallel是一个模型封装,支持并行GPU使用。要普通保存 DataParallel 模型, 请保存model.module.state_dict()。 这样,你就可以非常灵活地以任何方式加载模型到你想要的设备中。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值