pytorch 计算网络模型的计算量FLOPs和参数量parameter之殊途同归


在进行论文撰写时,我们通常要通过计算网络模型的计算量FLOPs和参数量parameter来评估模型的性能,本文总结了几种常用的计算方式,大家可以尝试一下。
为了能够便于读者理解,我们选取pytorch自带的网络resnet34进行测试,也可自行更改为其他或所提网络

参数量方法一:pytorch自带方法,计算模型参数总量

from torchvision.models import resnet34

net=resnet34() # 注意:模型内部传参数和不传参数,输出的结果是不一样的
# 计算网络参数
total = sum([param.nelement() for param in net.parameters()])
# 精确地计算:1MB=1024KB=1048576字节
print('Number of parameter: % .4fM' % (total / 1e6))

输出:

Number of parameter:  21.7977M

参数量方法二: summary的使用:来自于torchinfo第三方库

torchinfo 的 summary 更加友好,我个人觉得是 print 和 torchsummary 的 summary 的结合体!推荐!!!

from torchvision.models import resnet34
import torch
from torchinfo import summary  # 注意:当使用from torchsummary import summary时,对应的summary应该为:summary(model, input_size=(3, 512, 512), batch_size=-1)
if __name__ == "__main__":
    model = resnet34()
    tmp_0 = model(torch.rand(1, 3, 224, 224).cuda())
    print(tmp_0.shape)

    summary(model, (1, 3, 224, 224))# summary的函数内部参数形式与导入的第三方库有关,否则报错

输出结果如下:
在这里插入图片描述
在这里插入图片描述

参数量方法三: summary的使用:来自于torchsummary第三方库

torchsummary 中的 summary 可以打印每一层的shape, 参数量,

from torchvision.models import resnet34
from torchsummary import summary
model = resnet34()
summary(model, input_size=(3, 512, 512), batch_size=-1)# 同样是summary函数,注意与方法二的区别

输出结果如下:

在这里插入图片描述
在这里插入图片描述

计算量方法一:thop的使用,输出计算量FLOPs和参数量parameter

注意区分FLOPs和FLOPS
FLOPs就是表示模型前向传播中计算MAC(乘法加法操作的次数),如果FLOPs的值越大,也从一定程度上说明模型越复杂,模型需要的计算力(算力)更高,因此对硬件的要求也就越高!

from torchvision.models import resnet34
import torch
from thop import profile
if __name__ == "__main__":
    # #call Transception_res

    model = resnet34()
    input = torch.randn(1, 3, 512, 512)
    Flops, params = profile(model, inputs=(input,)) # macs
    print('Flops: % .4fG'%(Flops / 1000000000))# 计算量
    print('params参数量: % .4fM'% (params / 1000000)) #参数量:等价与上面的summary输出的Total params值

该网络模型中包含该方法的计算:https://github.com/Barrett-python/DuAT/blob/main/DuAT.py

输出结果:输出为网络模型的总参数量(单位M,即百万)与计算量(单位G,即十亿)

Flops:  19.2174G
params参数量:  21.7977M

在这里插入图片描述
参考链接:

  1. CNN 模型的参数(parameters)数量和浮点运算数量(FLOPs)是怎么计算的https://blog.csdn.net/weixin_41010198/article/details/108104309
  2. 区分FLOPs和FLOPS:https://blog.csdn.net/IT_flying625/article/details/104898152
  3. pytorch得到模型的计算量和参数量https://blog.csdn.net/qq_35407318/article/details/109359006
  4. 轻量化网络中常使用的参数量和计算量评估;https://blog.csdn.net/weixin_46274756/article/details/130391999如下图所示
    在这里插入图片描述
  5. Pytorch 中打印网络结构及其参数的方法与实现https://blog.csdn.net/like_jmo/article/details/126903727
  6. CNN 模型所需的计算力flops是什么?怎么计算?https://zhuanlan.zhihu.com/p/137719986
  7. FLOPS的计算:https://blog.csdn.net/baidu_35848778/article/details/127571810
  • 9
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论
可以使用下面的代码计算PyTorch模型FLOPs(浮点操作次数): ```python import torch from torch.autograd import Variable def print_model_parm_flops(model, input_size, custom_layers): multiply_adds = 1 params = 0 flops = 0 input = Variable(torch.rand(1, *input_size)) def register_hook(module): def hook(module, input, output): class_name = str(module.__class__).split(".")[-1].split("'")[0] if class_name == 'Conv2d': out_h, out_w = output.size()[2:] kernel_h, kernel_w = module.kernel_size in_channels = module.in_channels out_channels = module.out_channels if isinstance(module.padding, int): pad_h = pad_w = module.padding else: pad_h, pad_w = module.padding if isinstance(module.stride, int): stride_h = stride_w = module.stride else: stride_h, stride_w = module.stride params += out_channels * (in_channels // module.groups) * kernel_h * kernel_w flops += out_channels * (in_channels // module.groups) * kernel_h * kernel_w * out_h * out_w / (stride_h * stride_w) elif class_name == 'Linear': weight_flops = module.weight.nelement() * input[0].nelement() // module.weight.size(1) bias_flops = module.bias.nelement() flops = weight_flops + bias_flops params = weight_flops + bias_flops elif class_name in custom_layers: custom_flops, custom_params = custom_layers[class_name](module, input, output) flops += custom_flops params += custom_params else: print(f"Warning: module {class_name} not implemented") if not isinstance(module, torch.nn.Sequential) and \ not isinstance(module, torch.nn.ModuleList) and \ not (module == model): hooks.append(module.register_forward_hook(hook)) # loop through the model architecture and register hooks for each layer hooks = [] model.apply(register_hook) # run the input through the model model(input) # remove the hooks for hook in hooks: hook.remove() print(f"Number of parameters: {params}") print(f"Number of FLOPs: {flops}") return flops, params ``` 调用这个函数需要传入模型、输入大小和一个自定义图层字典,其中字典的键是自定义层的名称,值是一个函数,该函数接受模块,输入和输出,返回FLOPs数量。例如,如果您的模型包含一个名为MyLayer的自定义层,则可以将以下内容添加到字典中: ```python def my_layer_impl(module, input, output): # compute FLOPs and params for MyLayer flops = ... params = ... return flops, params custom_layers = {'MyLayer': my_layer_impl} ``` 使用示例: ```python import torchvision.models as models model = models.resnet18() input_size = (3, 224, 224) custom_layers = {} flops, params = print_model_parm_flops(model, input_size, custom_layers) ``` 该函数将打印出模型数量FLOPs

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

三少的笔记

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值