Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)

相关文章

Pytorch学习笔记(一):torch.cat()模块的详解
Pytorch学习笔记(二):nn.Conv2d()函数详解
Pytorch学习笔记(三):nn.BatchNorm2d()函数详解
Pytorch学习笔记(四):nn.MaxPool2d()函数详解
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
Pytorch学习笔记(六):view()和nn.Linear()函数详解
Pytorch学习笔记(七):F.softmax()和F.log_softmax函数详解
Pytorch学习笔记(八):nn.ModuleList和nn.Sequential函数详解
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)

1.方法

torchstat.stat:计算Pytorch模型的FLOPs、模型参数量、MAdd、模型显存占用量
thop:工具包仅支持FLOPs和参数量的计算
ptflops:统计 参数量 和 FLOPs
torchsummary:用来计算网络的计算参数等信息

下载不下来用
-i http://pypi.douban.com/simple --trusted-host pypi.douban.com

2.计算换算

FLOPs是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量,用以衡量算法/模型复杂度。
MACs 每秒执行的定点乘累加操作次数的缩写,它是衡量计算机定点处理能力的量,这个量经常用在那些需要大量定点乘法累加运算的科学运算中,记为MACs。

  • 一个 MFLOPS (megaFLOPS) 等于每秒1百万 (=10^6) 次的浮点运算,
  • 一个 GFLOPS (gigaFLOPS) 等于每秒10亿 (=10^9) 次的浮点运算,
  • 一个 TFLOPS (teraFLOPS) 等于每秒1万亿 (=10^12) 次的浮点运算,
  • 一个 PFLOPS (petaFLOPS) 等于每秒1千万亿 (=10^15) 次的浮点运算
  • GMAC=0.5GFLOPs

3.具体代码

代码1:

def shufflenet_1x(num_classes=10):
    return ShuffleNetV2(1, num_classes)
model=shufflenet_1x()
#a=torch.randn(1,3,224,224)
# print(model(a))
"""通过torchstat.stat 可以查看网络模型的参数量和计算复杂度FLOPs"""
from torchstat import stat
# stat(model,(3,224,224))
#=======================================================================================================================================================
#Total params: 561,706
#-------------------------------------------------------------------------------------------------------------------------------------------------------
#Total memory: 6.88MB
#Total MAdd: 79.01MMAdd
#Total Flops: 39.96MFlops
#Total MemR+W: 14.08MB
"""thop工具包仅支持FLOPs和参数量的计算"""
# from thop import profile
# from thop import clever_format
# input=torch.randn(1,3,224,224)
# flops, params = profile(model, inputs=(input, ))
# print(flops, params) # 46388784.0 561706.0
# flops, params = clever_format([flops, params], "%.3f")
# print(flops, params) # 46.389M 561.706K
"""ptflops统计 参数量 和 FLOPs"""
from ptflops import get_model_complexity_info
macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
                                       print_per_layer_stat=True, verbose=True)
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))
#Computational complexity:       0.05 GMac
#Number of parameters:           1.26 M
"""torchsummary 用来计算网络的计算参数等信息"""
from torchsummary import summary
model2 = ShuffleNetV2(scale=1.5, in_channels=3, c_tag=0.5, num_classes=10, activation=nn.ReLU,SE=False, residual=False)
a=torch.randn(1,3,224,224)
summary(model2.cuda(),input_size=(3,224,224))
# ================================================================
#Total params: 2,489,770
#Trainable params: 2,489,770
#Non-trainable params: 0
#----------------------------------------------------------------
#Input size (MB): 0.57
#Forward/backward pass size (MB): 62.77
#Params size (MB): 9.50
#Estimated Total Size (MB): 72.84
#----------------------------------------------------------------

代码2:

from nanodet.model.arch import build_model
from nanodet.util import cfg, load_config
from torchvision.models import resnet50
import torch
config = 'D:/pycharm/nanodet-main/config/nanodet-m.yml'
load_config(cfg, config)
model = build_model(cfg.model)
model1 = resnet50()
"""通过torchstat.stat 可以查看网络模型的参数量和计算复杂度FLOPs"""
from thop import profile
input = torch.randn(1,3,224,224)
flops,params = profile(model,inputs=(input,))
print('the flops is {}G,the params is {}M'.format(round(flops/(10**9),2), round(params/(10**6),2))) # 4111514624.0 25557032.0 res50


  • 13
    点赞
  • 62
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ZZY_dl

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值