深度学习--快速有效实现模型数据的保存和读取

目录

前提

一、为什么保存模型的数据?

二、模型数据保存

1、保存模型参数

2、保存完整模型

三、模型数据读取

四、总结


前提

本次对模型数据的保存和读取进行了分离,分别配属于不同的文件,便于管理和使用,相关文件自行下载参考

模型数据:

class CNN(nn.Module):
    def __init__(self):         # 输入大小 (3, 256, 256)
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(  #将多个层组合成一起。
            nn.Conv2d(          #2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据
                in_channels=3,  # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数),
                out_channels=16,# 要得到几多少个特征图,卷积核的个数
                kernel_size=5,  # 卷积核大小,5*5
                stride=1,       # 步长
                padding=2,      # 一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好。那padding改如何设计呢?建议stride为1,kernel_size = 2*padding+1
            ),                  # 输出的特征图为 (16, 256, 256)
            nn.ReLU(),  # relu层
            nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域), 输出结果为: (16, 128, 128)
        )
        self.conv2 = nn.Sequential(  #输入 (16, 128, 128)
            nn.Conv2d(16, 32, 5, 1, 2),  # 输出 (32, 128, 128)
            nn.ReLU(),  # relu层
            nn.Conv2d(32, 32, 5, 1, 2), # 输出 (32, 128, 128)
            nn.ReLU(),
            nn.MaxPool2d(2),  # 输出 (32, 64, 64)
        )

        self.conv3 = nn.Sequential(  #输入 (32, 64, 64)
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),  # 输出 (64, 64, 64)
        )

        self.out = nn.Linear(64 * 64 * 64, 20)  # 全连接层得到的结果

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)# 输出 (64,64, 32, 32)
        x = x.view(x.size(0), -1)  # flatten操作,结果为:(batch_size, 64 * 32 * 32)
        output = self.out(x)
        return output

一、为什么保存模型的数据?

在许多应用场景中,往往需要在不同的时间、地点和环境中使用同一个模型。如果模型参数存储在硬盘上,就可以随时加载并使用,非常方便。而且重新训练一个深度学习模型可能需要花费数小时,甚至数天的时间,这取决于模型的复杂性和数据集的大小。因此,保存模型可以避免在每个项目或每次需要使用模型时重新训练。

在运行深度学习模型时,往往需要大量的内存(RAM)。如果模型的参数存储在硬盘上而不是内存中,可以大大减少内存的使用,这在内存有限的系统上尤其有用。在深度学习的迁移学习中,一个已经在一个任务上训练好的模型被用作新任务的基础模型。对新任务只需要训练部分特定的层,这样可以大大节省计算资源并提高效率。保存模型使得这种训练方式变得可能。

二、模型数据保存

1、保存模型参数
torch.save(model.state_dict(), path)

保存模型的参数,model.state_dict()返回一个包含模型所有参数的字典对象,包括卷积层的权重和偏置,全连接层的权重和偏置等。这个字典对象可以被加载回模型以进行后续的训练或者评估。

2、保存完整模型
torch.save(model, 'best.pt')

将完整的模型数据保存到‘best.pt’文件中,实现模型的存储。

三、模型数据读取

在生成保存文件后,如果想要读取该数据,但又不想在原文件中进行读取,可以使用下述方法

import torch

model = torch.load('best.pt') #加载模型数据文件
model.eval() #固定模型数据和参数,防止后面被修改
print(model)

使用 'torch.load('best.pt')' 加载了之前保存的‘best.pt’模型'。model.eval()' 是将模型设置为评估模式,这在模型训练完毕并加载之后非常常见,它主要用于模型的预测或者评估阶段。最后,'print(model)' 是打印出模型的结构或者参数等信息。

注:此代码可单独运行,但必须与所加载的文件处于同一目录下

读取结果展示:

四、总结

torch.loadtorch.save是PyTorch库中的两个用于模型加载和保存的函数。torch.save函数用于将模型、张量或其他PyTorch对象保存到文件或路径中,可以选择不同的序列化方式。torch.load函数用于从文件中或路径中加载已保存的模型、张量或其他PyTorch对象。可以在不同文件中实现对模型数据的保存和加载。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值