【pytorch】使用stat、profile打印网络的参数量、Flops、MAdd、内存使用的情况

目的】pytorch获取网络的参数量、MAdd、Flops
可使用库】torchstat中的stat、thop中的profile

1 stat打印

安装工具】pip install torchstat
使用例子】我们的网络只有一层,该层的数据就是整个模型的数据。
这里并没有严格按照pytorch官方提供的公式计算,个人感觉不是很好记忆;这里是使用实际的例子,来将计算方式具体化,反向的去理解公式

import torch
import torch.nn as nn
from torchstat import stat

class Net(nn.Module):
 def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 2, kernel_size=7, stride=2, padding=3, bias=False)
    def forward(self, x):
       x = self.conv1(x)

net = Net()
stat(net, (3, 500, 500))

在这里插入图片描述
打印内容
stat命令打印的结果如上,我们分别分析下参数的计算方式:

  • 【params】
    网络主要的参数量
    7*7*3*2=294(W*H*C_in*C_out),因为这里偏置设为False,所有不加上偏置的参数量
  • 【memory】
    stat源码定义下查看,该参数的定义应为:节点推理时候所需的内存(具体计算公式本人暂不清楚,如果有了解的期待评论告知,多谢)
    在这里插入图片描述
  • 【Flops】
    网络完成的浮点运算。这里计算以输出的Feature map为视角,其中每个元素的计算需要经历
    ((7*7*3)+(7*7*3-1))*(250*250*2) = 36625000 ~= 36.62 MFlops
    ((输出一个元素所经历的乘法次数)+(输出一个元素所经历的加法的个数))*(输出总共的元素的个数)
  • 【MAdd】
    网络完成的乘加操作的数量。一次乘加=一次乘法+一次加法,所以可以粗略的认为:Flops ~=2*MAdd
    (7*7*3)*(250*250*2) = 18375000 ~= 18.38 MMAdd
  • 【MemRead】
    网络运行时,从内存中读取的大小 = 输入的特征图大小 + 网络参数的大小
    ((500*500*3) + (7*7*3*2))*4 = 3001176.0
    这里乘以4,是因为假设这里的数是float32的,一个float32=4*byte
  • 【MemWrite】
    网络运行时,写入到内存中的大小 = 输出的特征图大小
    250*250*2*4 = 500000
  • 【MemR+W】
    MemR+W = MemRead + MemWrite。在这里等于 3001176.0+500000 = 3501176.0

能够发现,按照公式计算的 Flops/MAdd 两个变量刚好反了,按道理pytorch应该不会出现如此明显的bug,但的确自己计算是按照定义计算的,这个问题就只能先保留在这

2023.03.24:与同事沟通发现,对于某一层的网络相关计算,池化操作不满足:Flops ~=2*MAdd(有兴趣的同学自己可打印查看下),但完整的网络模型中绝大部分为卷积,池化操作的影响可忽略不计

2 profile打印

可看到打印结果与stat相应数值大小基本一致

import torch
import torch.nn as nn
# from torchstat import stat
from thop import profile

class Net(nn.Module):
   def __init__(self):
       super(Net, self).__init__()
       self.conv1 = nn.Conv2d(3, 2, kernel_size=7, stride=2, padding=3, bias=False)
   def forward(self, x):
       x = self.conv1(x)

# net = Net()
# stat(net, (3, 500, 500))

input = torch.randn(1, 3, 500, 500)
flops, params = profile(net, inputs=(input,))
print('FLOPs = ' + str(flops/1000**3) + 'G')
print('Params = ' + str(params/1000**2) + 'M')

在这里插入图片描述

  • 17
    点赞
  • 69
    收藏
    觉得还不错? 一键收藏
  • 20
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值