PyTorch框架训练的几种模型区别

14 篇文章 0 订阅
8 篇文章 0 订阅

PyTorch系列文章目录



前言

在PyTorch中,.pt、.pth和.pth.tar都是用于保存训练好的模型的文件格式,它们之间的主要区别如下:

.pt文件是PyTorch 1.6及以上版本中引入的新的模型文件格式,它可以保存整个PyTorch模型,包括模型结构、模型参数以及优化器状态等信息。.pt文件是一个二进制文件,可以通过torch.save()函数来保存模型,以及通过torch.load()函数来加载模型。

.pth文件是PyTorch旧版本中使用的模型文件格式,它只保存了模型参数,没有保存模型结构和其他相关信息。.pth文件同样是一个二进制文件,可以通过torch.save()函数来保存模型参数,以及通过torch.load()函数来加载模型参数。

.pth.tar文件是一个压缩文件,它包含一个.pth文件以及其他相关信息,比如模型结构、优化器状态、超参数等。.pth.tar文件可以通过Python的标准库tarfile来解压,然后通过torch.load()函数来加载模型。

总的来说,.pt文件是最新的、最全面的模型保存格式,可以保存整个PyTorch模型,包括模型结构、参数、优化器状态等信息。.pth文件只保存了模型参数,而.pth.tar文件则是在.pth基础上加入了一些元数据信息,可以方便地保存和加载整个模型状态。在实际应用中,我们可以根据需要选择适合自己的模型保存格式。


一、.pt模型使用介绍

.pt模型文件是PyTorch框架中保存模型权重的文件格式,其结构包含以下几个部分:
Header:文件开头的一段信息,包含了PyTorch版本、模型结构等元数据信息。
State dictionary:模型的权重数据,以Python的字典形式保存。每个键对应了模型的一个参数名,值则是对应的权重矩阵或向量。
Optimizer state:如果模型使用了优化器,那么这里保存了优化器的状态信息,包括当前的学习率、动量等参数。
Other metadata:保存了一些附加的元数据信息,比如模型训练时使用的超参数、训练数据集的统计信息等。
要解读.pt模型文件的信息,可以使用PyTorch提供的torch.load()函数来加载模型文件,然后可以通过访问字典中的键值对来获取模型的权重和其他信息。例如,可以使用以下代码加载模型文件并查看模型结构和权重:

import torch
model = torch.load('model.pt')
print(model)

该代码会输出模型的结构和权重信息,可以通过访问字典中的键值对来获取具体的权重数值。例如,可以使用以下代码获取模型中名为’conv1.weight’的卷积层权重矩阵:

weights = model['conv1.weight']
print(weights)

这样就可以查看模型文件中保存的权重信息,并进一步用于模型的部署或微调等操作。

二、.pth模型使用介绍

Pytorch目前成为学术界最流行的DL框架,没有之一。很大程度上,简洁直观地操作有关。模型的保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。

首先,要清楚几个函数:torch.save,torch.load,state_dict(),load_state_dict()。
先举最简单的例子:

import torch

model = torch.load('my_model.pth')
torch.save(model, 'new_model.pth')

上面的代码非常直观,一载一存。但是有一个问题,这样保存的pth文件直接包含了整个模型的结构。当你需要灵活加载模型参数时,比如只加载部分参数,那么这种情况保存的pth文件读取进来还得额外解析出“参数文件”。

如果想更灵活对待咱们训练好的模型参数,咱们可以使用下面这个方法。pytorch把所有的模型参数用一个内部定义的dict进行保存,自称为“state_dict”。这个所谓的state_dict就是不带模型结构的模型参数了~
咱们的加载和保存就发生了一点微妙的变化:

import torch
model = MyModel() # init your model class, build the graph shape
state_dict = torch.load('model_state_dict.pth')
model.load_state_dict(state_dict)
torch.save(model.state_dict(), 'model_state_dict1.pth')

比较上面两段代码,咱们可以有一下结论:

pth文件既可能保存了模型的图结构,也有可能没保存;
加载没保存图结构的pth时,需要先初始化模型结构,即把架子搭好;
在保存模型的时候,如果不想保存图结构,可以单独保存model.state_dict()

实验
脚本如下:

import torch
import torchvision.models as models

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'only_weights.pth')

model_state_dict = torch.load('only_weights.pth')
model1 = models.vgg16() # describe the graph shape
model1.load_state_dict(model_state_dict)
model1.eval()

torch.save(model1, 'whole_model.pth')

model2 = torch.load('whole_model.pth')
model2.eval()

# model3 = torch.load('only_weights.pth')
# model3.eval()    # Error

model3切换到eval()模式就会报错,原因是model3只包含weights而缺乏图结构~

三、.pth.tar模型使用介绍

由于为我的特定应用程序重新训练初始模型需要大量计算资源,我想使用已经重新训练的模型。
此模型保存为 .pth.tar文件。
我希望能够首先加载这个模型。到目前为止,我已经能够弄清楚我必须使用以下内容:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')

这似乎有效,因为 print(model)打印出大量数字和其他值,我认为这些值是权重和偏差的值。
在此之后,我需要能够用它对图像进行分类。我一直无法弄清楚这一点。我必须如何格式化图像?图像是否应该转换为数组?在此之后,我必须如何将输入数据传递给网络?

如果您有 .pth.tar文件,您可以加载它,从而覆盖已定义模型的参数值。

这意味着保存/加载模型的一般过程如下:
编写您的网络定义(即您的 nn.Module 对象)
以您想要的方式训练或以其他方式更改网络参数
使用 torch.save 保存参数
当您想使用该网络时,请使用 nn.Module 的相同定义对象首先实例化 pytorch 网络
然后使用 torch.load 覆盖网络参数的值

这是一个超短的 mwe:

四、.pkl模型

保存

torch.save({
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
}, 'filename.pth.tar')

加载

checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

总结

https://blog.csdn.net/Cretheego/article/details/128789192

  • 10
    点赞
  • 61
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一种基于Python的深度学习框架,能够提供高效的张量操作和动态构建计算图的能力。下面是PyTorch训练模型和使用模型的原理流程: 1. 数据准备:首先需要准备好训练集和测试集,并对数据进行预处理,例如标准化、归一化等操作。 2. 模型定义:使用PyTorch定义模型,可以选择使用现成的预训练模型,也可以自己定义模型。 3. 损失函数定义:选择合适的损失函数,例如交叉熵损失函数、均方误差等。 4. 优化器定义:选择合适的优化器,例如Adam、SGD等,用于更新模型参数。 5. 训练模型:将数据输入模型,计算损失函数,并根据优化器对模型参数进行更新。 6. 模型评估:使用测试集评估模型性能,计算模型的准确率、精确率、召回率等指标。 7. 模型保存:将训练好的模型保存下来,以便后续使用。 8. 使用模型:使用保存的模型对新数据进行预测或分类。 在使用PyTorch进行深度学习任务时,通常需要使用以下几个库: - torch:PyTorch的核心库,提供张量操作等基础功能。 - torchvision:提供了一些常用的计算机视觉数据集和模型。 - torchtext:提供了一些常用的自然语言处理数据集和模型。 - torchsummary:提供了一个方便的方式来查看模型的结构和参数数量。 总之,PyTorch是一种灵活、易于使用和扩展的深度学习框架,可以帮助开发者快速构建、训练和部署深度学习模型

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值