作者 | 涤生
编辑 | 3D视觉开发者社区
✨如果觉得文章内容不错,别忘了三连支持下哦😘~
FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。
不同的网络算子有不同的计算量,其计算方式也不相同。在卷积神经网络中,主要以卷积层和全连接层为主,其他算子的操作一般比较小,可以忽略不计,算作误差。二者其计算量的计算方式如下:
卷积
K K K表示卷积核大小, C i n C_{in} Cin和 C o u t C_{out} Cout表示输入和输出通道数, H o u t H_{out} Hout和 W o u t W_{out} Wout表示输出特征图大小。
情况一:乘法+加法+bias
其中n个数相加需要n-1次加法运算。
[(
C
i
n
C_{in}
Cin*
K
K
K *
K
K
K)+(
C
i
n
C_{in}
Cin
K
K
K *
K
K
K)] *
H
o
u
t
H_{out}
Hout *
W
o
u
t
W_{out}
Wout= 2
C
i
n
C_{in}
Cin
K
2
K^2
K2 *
H
o
u
t
H_{out}
Hout *
W
o
u
t
W_{out}
Wout *
C
o
u
t
C_{out}
Cout
情况二:乘法+加法,无bias
[( C i n C_{in} Cin* K K K * K K K)+( C i n C_{in} Cin * ( K K K * K K K-1)+( C i n C_{in} Cin-1)] * H o u t H_{out} Hout * W o u t W_{out} Wout * C o u t C_{out} Cout=( 2 C i n C_{in} Cin K 2 K^2 K2 -1) * H o u t H_{out} Hout * W o u t W_{out} Wout * C o u t C_{out} Cout
情况三:乘法,不计入加法
[(
C
i
n
C_{in}
Cin*
K
K
K *
K
K
K) *
H
o
u
t
H_{out}
Hout *
W
o
u
t
W_{out}
Wout *
C
o
u
t
C_{out}
Cout=
C
i
n
C_{in}
Cin
K
2
K^2
K2 *
H
o
u
t
H_{out}
Hout *
W
o
u
t
W_{out}
Wout *
C
o
u
t
C_{out}
Cout
参数量:
C
i
n
C_{in}
Cin*
K
h
K_{h}
Kh *
K
w
K_{w}
Kw *
C
o
u
t
C_{out}
Cout
全连接层
I I I表示输入维度, O O O表示输出维度
情况一:乘法+加法+bias
( I I I+ I I I)* O O O=2 I I I O O O
情况二:乘法+加法,无bias
( I I I+ I I I-1)* O O O=(2 I I I-1) * O O O
情况三:乘法,不计入加法
I I I* O O O= I I I O O O
参数量:
I I I* O O O
计算代码–以PyTorch框架为例
第一个推荐的计算库是 thop, 简单好用
① 安装:
推荐从作者的github直接安装最新版本。
1 pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git
② 使用:
1 #计算
2 from torchvision.models import resnet50
3 from thop import profile
4 model = resnet50()
5 dummy_input = torch.randn(1, 3, 224, 224)
6 macs, params = profile(model, inputs=(dummy_input, ))
7 #输出
8 from thop import clever_format
9 macs, params = clever_format([macs, params], "%.3f")
第二个推荐微软的nni包
nni中有个计算网络模型计算量和参数量的小工具,使用起来也非常简单,并能输出每一层的计算量和参数量。
① 安装:
1 pip install --upgrade nni
② 使用:
1 from torchvision.models import resnet18
2 from nni.compression.pytorch.utils.counter import count_flops_params
3 model = resnet18()
4 dummy_input = torch.randn(1, 3, 224, 224)
5 flops, params, results = count_flops_params(model, dummy_input)
总结:
浮点计算量在一定程度上能估计模型的计算复杂度,但是却不一定能代表真实的推理时间。
这主要是因为,不同的硬件或框架对不同算子的优化程度不同,即使计算量大但计算速度也可能比较快,所以FLOPs这一指标只是网络计算复杂度的一个参考,有理论意义。
本文作者原创于知乎,内容有参考https://www.zhihu.com/question/65305385的回答。
版权声明:本文为奥比中光3D视觉开发者社区特约作者授权原创发布,未经授权不得转载,本文仅做学术分享,版权归原作者所有,若涉及侵权内容请联系删文
3D视觉开发者社区是由奥比中光给所有开发者打造的分享与交流平台,旨在将3D视觉技术开放给开发者。平台为开发者提供3D视觉领域免费课程、奥比中光独家资源与专业技术支持。
点击加入3D视觉开发者社区,和开发者们一起讨论分享吧~
也可移步微信关注官方公众号:3D视觉开发者社区 ,获取更多干货知识哦~