pytorch中模型文件.pth的解析

在pytorch进行模型保存的时候,一般有两种保存方式,一种是保存整个模型,另一种是只保存模型的参数。
torch.save(model.state_dict(), "my_model.pth")  # 只保存模型的参数
torch.save(model, "my_model.pth")  # 保存整个模型

保存的模型参数实际上一个字典类型,通过key-value的形式来存储模型的所有参数,本文以自己在实践过程中使用的一个.pth文件为例来说明,使用的是整个模型。

1.1 .pth 文件基本信息的查看

import torch
 
pthfile = r'F:/GNN/graph-rcnn/graph-rcnn/datasets/sg_baseline_ckpt.pth'  #faster_rcnn_ckpt.pth
net = torch.load(pthfile,map_location=torch.device('cpu')) # 由于模型原本是用GPU保存的,但我这台电脑上没有GPU,需要转化到CPU上
 
print(type(net))  # 类型是 dict
print(len(net))   # 长度为 4,即存在四个 key-value 键值对
 
for k in net.keys():
    print(k)      # 查看四个键,分别是 model,optimizer,scheduler,iteration

1.2 模型的四个键值分别详解

(1)net[“model”] 详解
# print(net["model"]) # 返回的是一个OrderedDict 对象
for key,value in net["model"].items():
    print(key,value.size(),sep="   ")
'''运行结果如下:
module.backbone.body.stem.conv1.weight   torch.Size([64, 3, 7, 7])
module.backbone.body.stem.bn1.weight   torch.Size([64])
module.backbone.body.stem.bn1.bias   torch.Size([64])
module.backbone.body.stem.bn1.running_mean   torch.Size([64])
module.backbone.body.stem.bn1.running_var   torch.Size([64])
module.backbone.body.layer1.0.downsample.0.weight   torch.Size([256, 64, 1, 1])
module.backbone.body.layer1.0.downsample.1.weight   torch.Size([256])
.
.
.
module.backbone.body.layer3.22.bn3.weight   torch.Size([1024])
module.backbone.body.layer3.22.bn3.bias   torch.Size([1024])
module.backbone.body.layer3.22.bn3.running_mean   torch.Size([1024])
module.backbone.body.layer3.22.bn3.running_var   torch.Size([1024])
.
.
.
module.rpn.head.conv.bias   torch.Size([1024])
module.rpn.head.cls_logits.weight   torch.Size([15, 1024, 1, 1])
module.rpn.head.cls_logits.bias   torch.Size([15])
module.rpn.head.bbox_pred.weight   torch.Size([60, 1024, 1, 1])
.
.
.
module.roi_heads.box.feature_extractor.head.layer4.0.bn2.running_var   torch.Size([512])
module.roi_heads.box.feature_extractor.head.layer4.0.conv3.weight   torch.Size([2048, 512, 1, 1])
module.roi_heads.box.feature_extractor.head.layer4.0.bn3.weight   torch.Size([2048])
.
.
.
module.roi_heads.relation.predictor.cls_score.weight   torch.Size([51, 2048])
module.roi_heads.relation.predictor.cls_score.bias   torch.Size([51])
'''

总结:键model所对应的值是一个OrderedDict,而这个OrderedDict字典里面又存储着所有的每一层的参数名称以及对应的参数值。

需要注意的是,这里参数名称之所以很长,如:

module.backbone.body.stem.conv1.weight

是因为搭建网络结构的时候采用了组件式的设计,即整个模型里面构造了一个backbone的容器组件,backbone里面又构造了一个body容器组件,body里面又构造了一个stem容器。

(2)net[“optimizer”]详解
# print(net["optimizer"]) # 返回的是一个一般的字典 Dict 对象
for key,value in net["optimizer"].items():
    print(key,type(value),sep="    ")
'''运行结果为:
state    <class 'dict'>
param_groups    <class 'list'>
'''
'''
发现这个这个字典只有两个key,一个是state,一个是param_groups
其中state所对应的值又是一个字典类型,
param_groups对应的值是一个列表
'''

继续往下查看得到

先看一下net[“optimizer”][“param_groups”] 这个列表里面放了一下啥:

groups=net["optimizer"]["param_groups"]
print(groups)
print(len(groups))  # 返回115.即在这个模型中,共有115组
 
'''
[{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644061240]}, 
 {'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644061960]}, 
 {'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644062248]}, 
 {'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644077336]},
.
.
.
{'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566644061960]}, 
 {'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566103171936]}, 
 {'lr': 5.000000000000001e-05, 'weight_decay': 0.0005, 'momentum': 0.9, 'dampening': 0, 'nesterov': False, 'initial_lr': 0.005, 'params': [140566103172008]}
]
'''

这个列表的长度为115,每一个元素又是一个字典。

再看一下net[“optimizer”][“states”] 这个字典里面放了啥:

state=net["optimizer"]["state"]
print(len(state))  # 返回115.即在这个模型中,state共有115组
 
for key,value in state.items():
    print(key,type(value),sep="    ")
'''
140566644061240    <class 'dict'>
140566644061960    <class 'dict'>
140566644062248    <class 'dict'>
140566644077336    <class 'dict'>
.
.
.
140566103171936    <class 'dict'>
140566103172008    <class 'dict'>
'''

这个字典的长度是115,而且和前面的param_groups有着对应关系,每一个元素的键值就是param_groups中每一个元素的params。

继续往深一层看:

print(type(state[140566644061240]))  # 他又是一个字典
for key,value in state[140566644061240].items():
    print(key,value.size(),sep="   ")
 
'''
<class 'dict'>
momentum_buffer   torch.Size([512, 256, 1, 1])
'''
(3)net[“scheduler”] 详解
scheduler=net["scheduler"]  # 返回的依然是一个字典
print(len(scheduler))       # 字典的长度为 7
print(scheduler)
'''
{'milestones': (70000, 90000), 
 'gamma': 0.1, 
 'warmup_factor': 0.3333333333333333,
 'warmup_iters': 500, 
 'warmup_method': 'linear', 
 'base_lrs': [0.005, 0.005,  0.005, 0.01, ......, 0.005, 0.005, 0.005, 0.005, 0.01], 
 'last_epoch': 99999}
'''

继续看一下这个base_lrs的信息

print(len(scheduler["base_lrs"]))  # 返回115,→115个数组成的一个列表
(4)net[“iteration”] 详解
print(net["iteration"])  # 返回 9999 ,它是一个具体的数字

二、关于.pth 文件的总结

它是一个包含 四组 “key-value”的字典(这四组不一定同时存在,需要看具体情况,比如可能没有学习率衰减策略scheduler),类型分别如下:

在这里插入图片描述

其中

(1)net[“model”] 就相当于是 前面文章中说到的 net.state_dict() 返回的那个字典;

(2)net[“optimizer”] 就相当于是 前面文章中说到的 optimizer.state_dict() 返回的那个字典

参考链接:https://blog.csdn.net/qq_33590958/article/details/103543128

  • 14
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: .pth.tar 是一个文件扩展名,通常用于存储 PyTorch 模型的状态字典或是训练过程的检查点。而 .pth则是一个文件扩展名,用于存储 PyTorch 模型的脚本形式。 将 .pth.tar 和 .pth 文件互换,需要进行相应的处理来转换它们之间的格式。 如果想将 .pth.tar 文件转换为 .pth 文件,可以按照以下步骤进行: 1. 解压缩 .pth.tar 文件。可以使用压缩软件或是命令行工具,将 .pth.tar 文件解压缩到一个指定的目录。 2. 运行脚本导出模型。通常,在解压缩后的目录,会有一个 Python 脚本文件,用于导出模型。可以运行这个脚本,将模型保存为 .pth 文件。 如果想将 .pth 文件转换为 .pth.tar 文件,可以按照以下步骤进行: 1. 创建一个新的目录。在某个目录下,创建一个新的文件夹,用于存放转换后的 .pth.tar 文件。 2. 复制 .pth 文件到新目录。将原始的 .pth 文件复制到新的目录。 3. 打包 .pth 文件为 .pth.tar 文件。使用压缩软件或是命令行工具,将新目录下的 .pth 文件打包为一个 .pth.tar 文件。 通过上述步骤,可以将 .pth.tar 和 .pth 文件互换格式,使其适应不同的需求。需要注意的是,在转换格式过程,可能需要进行其他操作或是添加一些额外的信息,以确保文件的完整性和正确性。 ### 回答2: .pth.tar和.pthPython常见的文件类型,用于存储模型或包的扩展库。两者基本上具有相同的功能,可以在Python加载模型或包。 .pthPython的一个文件扩展名,用于存储模块的路径信息。当Python解释器在导入模块时,会在指定的路径搜索这些.pth文件,并将路径添加到模块搜索路径。 .pth.tar是对.pth进行压缩和归档的文件格式。它是对.pth文件进行打包,以便更好地管理和传输。 要互换.pth.tar和.pth,需要使用相关的工具进行转换。可以使用tar命令将.pth.tar文件解压缩为.pth文件,然后使用.pth文件。同样,可以使用tar命令将.pth文件打包为.pth.tar文件。 例如,要将.pth.tar文件解压缩为.pth文件,可以使用以下命令: ``` tar -xvf filename.pth.tar ``` 这将解压缩文件并生成.pth文件。 同样,要将.pth文件打包为.pth.tar文件,可以使用以下命令: ``` tar -cvf filename.pth.tar filename.pth ``` 这将打包.pth文件并生成.pth.tar文件。 总之,.pth.tar和.pth文件在使用上没有本质的区别,只是一个被压缩和归档,另一个是单独的文件。根据需要,可以根据具体情况互相转换使用。 ### 回答3: .pth.tar 和 .pthPython 用于模块导入的文件扩展名。它们可以互相转换使用。 .pth.tar 文件是一个压缩文件,通常用于打包一组相关的 Python 模块。它可以包含多个 .py 文件或目录,并且可以通过解压缩操作获取其的内容。如果要将 .pth.tar 文件转换为 .pth 文件,可以执行以下步骤: 1. 使用解压缩工具(如 WinRAR 或 7-Zip)打开 .pth.tar 文件。 2. 从压缩文件提取出所有的 .py 文件或目录。 3. 创建一个名为 XXX.pth文件,其 XXX 是你想要的模块名或者功能名。 4. 在 .pth 文件,每行写入一个 .py 文件或目录的路径,表示要导入的模块或包的位置。 5. 将 .pth 文件放置在 Python 的 site-packages 目录下或者你自己配置的模块搜索路径。 现在,你可以通过 import XXX 来导入相关的模块或包。 与此相反,如果你有一个 .pth 文件,想把它转换为一个 .pth.tar 打包文件,可以执行以下步骤: 1. 创建一个名为 YYY.pth文件,其 YYY 是你想要的打包文件的名字。 2. 在 .pth 文件,每行写入一个 .py 文件或目录的路径,表示要包含在打包文件的模块或包的位置。 3. 打开命令行界面,进入到包含 .pth 文件的目录。 4. 执行以下命令来创建一个 .pth.tar 文件:tar -cf YYY.pth.tar YYY.pth 5. 现在你会在当前目录下看到一个名为 YYY.pth.tar 的文件,它是一个打包了 .pth 文件所列模块或包的压缩文件。 无论是 .pth.tar 还是 .pth 文件,它们都是为了方便 Python 的模块导入而存在的,可以根据需要在这两个格式之间进行转换。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值