【目的】pytorch获取网络的参数量、MAdd、Flops
【可使用库】torchstat中的stat、thop中的profile
1 stat打印
【安装工具】pip install torchstat
【使用例子】我们的网络只有一层,该层的数据就是整个模型的数据。
这里并没有严格按照pytorch官方提供的公式计算,个人感觉不是很好记忆;这里是使用实际的例子,来将计算方式具体化,反向的去理解公式import torch import torch.nn as nn from torchstat import stat class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 2, kernel_size=7, stride=2, padding=3, bias=False) def forward(self, x): x = self.conv1(x) net = Net() stat(net, (3, 500, 500))
【打印内容】
stat命令打印的结果如上,我们分别分析下参数的计算方式:
- 【params】
网络主要的参数量
7*7*3*2=294(W*H*C_in*C_out),因为这里偏置设为False,所有不加上偏置的参数量- 【memory】
到stat
源码定义下查看,该参数的定义应为:节点推理时候所需的内存(具体计算公式本人暂不清楚,如果有了解的期待评论告知,多谢)
- 【Flops】
网络完成的浮点运算。这里计算以输出的Feature map为视角,其中每个元素的计算需要经历
((7*7*3)+(7*7*3-1))*(250*250*2) = 36625000 ~= 36.62 MFlops
((输出一个元素所经历的乘法次数)+(输出一个元素所经历的加法的个数))*(输出总共的元素的个数)- 【MAdd】
网络完成的乘加操作的数量。一次乘加=一次乘法+一次加法,所以可以粗略的认为:Flops ~=2*MAdd
(7*7*3)*(250*250*2) = 18375000 ~= 18.38 MMAdd- 【MemRead】
网络运行时,从内存中读取的大小 = 输入的特征图大小 + 网络参数的大小
((500*500*3) + (7*7*3*2))*4 = 3001176.0
这里乘以4,是因为假设这里的数是float32的,一个float32=4*byte- 【MemWrite】
网络运行时,写入到内存中的大小 = 输出的特征图大小
250*250*2*4 = 500000- 【MemR+W】
MemR+W = MemRead + MemWrite。在这里等于 3001176.0+500000 = 3501176.0
能够发现,按照公式计算的 Flops/MAdd 两个变量刚好反了,按道理pytorch应该不会出现如此明显的bug,但的确自己计算是按照定义计算的,这个问题就只能先保留在这
2023.03.24:与同事沟通发现,对于某一层的网络相关计算,池化操作不满足:Flops ~=2*MAdd(有兴趣的同学自己可打印查看下),但完整的网络模型中绝大部分为卷积,池化操作的影响可忽略不计
2 profile打印
可看到打印结果与stat相应数值大小基本一致
import torch import torch.nn as nn # from torchstat import stat from thop import profile class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 2, kernel_size=7, stride=2, padding=3, bias=False) def forward(self, x): x = self.conv1(x) # net = Net() # stat(net, (3, 500, 500)) input = torch.randn(1, 3, 500, 500) flops, params = profile(net, inputs=(input,)) print('FLOPs = ' + str(flops/1000**3) + 'G') print('Params = ' + str(params/1000**2) + 'M')