深度学习卷积网络浮点计算量和参数量的计算(附Pytorch代码)

144 篇文章 73 订阅
60 篇文章 10 订阅

作者 | 涤生
编辑 | 3D视觉开发者社区
✨如果觉得文章内容不错,别忘了三连支持下哦😘~

FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。

不同的网络算子有不同的计算量,其计算方式也不相同。在卷积神经网络中,主要以卷积层和全连接层为主,其他算子的操作一般比较小,可以忽略不计,算作误差。二者其计算量的计算方式如下:

卷积

K K K表示卷积核大小, C i n C_{in} Cin C o u t C_{out} Cout表示输入和输出通道数, H o u t H_{out} Hout W o u t W_{out} Wout表示输出特征图大小。

情况一:乘法+加法+bias

其中n个数相加需要n-1次加法运算。
[( C i n C_{in} Cin* K K K * K K K)+( C i n C_{in} Cin K K K * K K K)] * H o u t H_{out} Hout * W o u t W_{out} Wout= 2 C i n C_{in} Cin K 2 K^2 K2 * H o u t H_{out} Hout * W o u t W_{out} Wout * C o u t C_{out} Cout

情况二:乘法+加法,无bias

[( C i n C_{in} Cin* K K K * K K K)+( C i n C_{in} Cin * ( K K K * K K K-1)+( C i n C_{in} Cin-1)] * H o u t H_{out} Hout * W o u t W_{out} Wout * C o u t C_{out} Cout=( 2 C i n C_{in} Cin K 2 K^2 K2 -1) * H o u t H_{out} Hout * W o u t W_{out} Wout * C o u t C_{out} Cout

情况三:乘法,不计入加法
[( C i n C_{in} Cin* K K K * K K K) * H o u t H_{out} Hout * W o u t W_{out} Wout * C o u t C_{out} Cout= C i n C_{in} Cin K 2 K^2 K2 * H o u t H_{out} Hout * W o u t W_{out} Wout * C o u t C_{out} Cout

参数量:
C i n C_{in} Cin* K h K_{h} Kh * K w K_{w} Kw * C o u t C_{out} Cout

全连接层

I I I表示输入维度, O O O表示输出维度

情况一:乘法+加法+bias

( I I I+ I I I)* O O O=2 I I I O O O

情况二:乘法+加法,无bias

( I I I+ I I I-1)* O O O=(2 I I I-1) * O O O

情况三:乘法,不计入加法

I I I* O O O= I I I O O O

参数量:

I I I* O O O

计算代码–以PyTorch框架为例

第一个推荐的计算库是 thop, 简单好用

① 安装:

推荐从作者的github直接安装最新版本。

1 pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

② 使用:

1 #计算
2 from torchvision.models import resnet50
3 from thop import profile
4 model = resnet50()
5 dummy_input = torch.randn(1, 3, 224, 224)
6 macs, params = profile(model, inputs=(dummy_input, ))
7 #输出
8 from thop import clever_format
9 macs, params = clever_format([macs, params], "%.3f")

第二个推荐微软的nni包

nni中有个计算网络模型计算量和参数量的小工具,使用起来也非常简单,并能输出每一层的计算量和参数量。

① 安装:

1 pip install --upgrade nni

② 使用:

1 from torchvision.models import resnet18
2 from nni.compression.pytorch.utils.counter import  count_flops_params
3 model = resnet18()
4 dummy_input = torch.randn(1, 3, 224, 224)
5 flops, params, results = count_flops_params(model, dummy_input)

在这里插入图片描述

总结:

浮点计算量在一定程度上能估计模型的计算复杂度,但是却不一定能代表真实的推理时间。

这主要是因为,不同的硬件或框架对不同算子的优化程度不同,即使计算量大但计算速度也可能比较快,所以FLOPs这一指标只是网络计算复杂度的一个参考,有理论意义。

本文作者原创于知乎,内容有参考https://www.zhihu.com/question/65305385的回答。

版权声明:本文为奥比中光3D视觉开发者社区特约作者授权原创发布,未经授权不得转载,本文仅做学术分享,版权归原作者所有,若涉及侵权内容请联系删文

3D视觉开发者社区是由奥比中光给所有开发者打造的分享与交流平台,旨在将3D视觉技术开放给开发者。平台为开发者提供3D视觉领域免费课程、奥比中光独家资源与专业技术支持。

点击加入3D视觉开发者社区,和开发者们一起讨论分享吧~
也可移步微信关注官方公众号:3D视觉开发者社区 ,获取更多干货知识哦~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值