相关文章
Pytorch学习笔记(一):torch.cat()模块的详解
Pytorch学习笔记(二):nn.Conv2d()函数详解
Pytorch学习笔记(三):nn.BatchNorm2d()函数详解
Pytorch学习笔记(四):nn.MaxPool2d()函数详解
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
Pytorch学习笔记(六):view()和nn.Linear()函数详解
Pytorch学习笔记(七):F.softmax()和F.log_softmax函数详解
Pytorch学习笔记(八):nn.ModuleList和nn.Sequential函数详解
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
✅💖⚠️▶️➡️🌿🍀🍄🌟⭐❄️✅💖⚠️▶️➡️🌿🍀🍄🌟⭐❄️✅💖⚠️▶️➡️🌿🍀🍄🌟⭐❄️✅💖⚠️
1.方法及详细介绍🌟
在深度学习模型分析领域,存在多种实用的工具和方法,以下是几种常用于分析 Pytorch 模型的工具:
- torchstat.stat:这是一款功能全面的工具,其主要功能在于能够精确计算 Pytorch 模型的多个关键指标。它可以计算出模型的浮点运算次数(FLOPs),这一指标对于评估模型的计算复杂度至关重要,能够帮助开发者了解模型在不同硬件平台上的运行效率。同时,它还能统计模型的参数量,模型参数量的多少直接影响模型的存储需求和训练时间。此外,它还可以计算出乘加运算量(MAdd)以及模型显存占用量,为模型的优化和部署提供了详细的数据支持。
- thop:该工具包在 Pytorch 模型分析中也具有重要作用。虽然其功能相对较为专注,仅支持对模型的 FLOPs 和参数量进行计算,但在这两个关键指标的计算上具有较高的准确性和效率。在一些对模型计算资源消耗和复杂度评估的场景中,thop 能够快速地为开发者提供所需的数据,帮助其对模型的规模和计算需求有一个清晰的认识,以便进行后续的优化和调整。
- ptflops:此工具主要致力于统计 Pytorch 模型的参数量和 FLOPs。通过对模型结构的深入分析,它能够准确地获取模型中各个层的参数数量以及在模型运行过程中所涉及的浮点运算次数。这对于评估模型的复杂度和性能表现具有重要意义,开发者可以根据这些统计数据来比较不同模型架构的优劣,从而选择最适合特定任务的模型结构。
- torchsummary:它是一个专门用来计算网络相关计算参数等信息的工具。在 Pytorch 模型开发过程中,开发者可以利用 torchsummary 快速获取网络的详细信息,如各层的输出形状、参数数量等。这些信息对于理解模型的结构和功能、调试模型以及优化模型的性能都具有重要的辅助作用,能够帮助开发者更高效地进行模型的开发和改进工作。
这些工具在 Pytorch 模型的开发、优化和分析过程中都发挥着不可或缺的作用,开发者可以根据具体的需求和场景选择合适的工具来获取所需的模型信息。
下载不下来用
-i http://pypi.douban.com/simple --trusted-host pypi.douban.com
2.计算换算🌟
FLOPs是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量,用以衡量算法/模型复杂度。
MACs 每秒执行的定点乘累加操作次数的缩写,它是衡量计算机定点处理能力的量,这个量经常用在那些需要大量定点乘法累加运算的科学运算中,记为MACs。
- 一个 MFLOPS (megaFLOPS) 等于每秒1百万 (=10^6) 次的浮点运算,
- 一个 GFLOPS (gigaFLOPS) 等于每秒10亿 (=10^9) 次的浮点运算,
- 一个 TFLOPS (teraFLOPS) 等于每秒1万亿 (=10^12) 次的浮点运算,
- 一个 PFLOPS (petaFLOPS) 等于每秒1千万亿 (=10^15) 次的浮点运算
- GMAC=0.5GFLOPs
3.具体代码🌟
代码1:
def shufflenet_1x(num_classes=10):
return ShuffleNetV2(1, num_classes)
model=shufflenet_1x()
#a=torch.randn(1,3,224,224)
# print(model(a))
"""通过torchstat.stat 可以查看网络模型的参数量和计算复杂度FLOPs"""
from torchstat import stat
# stat(model,(3,224,224))
#=======================================================================================================================================================
#Total params: 561,706
#-------------------------------------------------------------------------------------------------------------------------------------------------------
#Total memory: 6.88MB
#Total MAdd: 79.01MMAdd
#Total Flops: 39.96MFlops
#Total MemR+W: 14.08MB
"""thop工具包仅支持FLOPs和参数量的计算"""
# from thop import profile
# from thop import clever_format
# input=torch.randn(1,3,224,224)
# flops, params = profile(model, inputs=(input, ))
# print(flops, params) # 46388784.0 561706.0
# flops, params = clever_format([flops, params], "%.3f")
# print(flops, params) # 46.389M 561.706K
"""ptflops统计 参数量 和 FLOPs"""
from ptflops import get_model_complexity_info
macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
print_per_layer_stat=True, verbose=True)
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
#Computational complexity: 0.05 GMac
#Number of parameters: 1.26 M
"""torchsummary 用来计算网络的计算参数等信息"""
from torchsummary import summary
model2 = ShuffleNetV2(scale=1.5, in_channels=3, c_tag=0.5, num_classes=10, activation=nn.ReLU,SE=False, residual=False)
a=torch.randn(1,3,224,224)
summary(model2.cuda(),input_size=(3,224,224))
# ================================================================
#Total params: 2,489,770
#Trainable params: 2,489,770
#Non-trainable params: 0
#----------------------------------------------------------------
#Input size (MB): 0.57
#Forward/backward pass size (MB): 62.77
#Params size (MB): 9.50
#Estimated Total Size (MB): 72.84
#----------------------------------------------------------------
代码2:
from nanodet.model.arch import build_model
from nanodet.util import cfg, load_config
from torchvision.models import resnet50
import torch
config = 'D:/pycharm/nanodet-main/config/nanodet-m.yml'
load_config(cfg, config)
model = build_model(cfg.model)
model1 = resnet50()
"""通过torchstat.stat 可以查看网络模型的参数量和计算复杂度FLOPs"""
from thop import profile
input = torch.randn(1,3,224,224)
flops,params = profile(model,inputs=(input,))
print('the flops is {}G,the params is {}M'.format(round(flops/(10**9),2), round(params/(10**6),2))) # 4111514624.0 25557032.0 res50
✅💖⚠️▶️➡️🌿🍀🍄🌟⭐❄️✅💖⚠️▶️➡️🌿🍀🍄🌟⭐❄️✅💖⚠️▶️➡️🌿🍀🍄🌟⭐❄️✅💖⚠️