模型的运算次数,可用 FLOPs衡量,也就是浮点运算次数(FLoating-point OPerations),表征的是模型的时间复杂度。模型空间复杂度通过Parameters反映,即模型的参数量。
最早是在2017年ICLR会议上由英伟达公司提出FLOPs的计算:从这张图中可以看出,随着模型层数不断加深,相应的flops增大,时间复杂度提高,但模型的验证集的错误率在不断减小。因此我们希望在模型加深获得更高精度的前提下尽可能减少模型的时间复杂度,加快模型的训练和预测时间。
单个卷积核的时间复杂度
M为每个卷积核输出特征图(Feature Map)的长宽,输出特征图尺寸由输入特征图尺寸X,卷积核尺寸K,填充层Padding,步长stride四个参数决定,公式表示为:
K为每个卷积核(kernel size)的大小
IC为输入通道数,OC为输出通道数
注1:为了简化表达式中的变量个数,这里统一假设输入和卷积核的形状都是正方形。
注2:严格来讲每层应该还包含 1 个 bias参数,为了简洁就省略了。
模型整体时间复杂度
D为模型全部卷积层数
依据层内相乘,层间相加的准则,将其累计起来就是整个模型所有卷积层的时间复杂度。
除此之外。模型的其它层结构(激活函数层,上下采样,batch normalization,池化层等)同样具有时间复杂度,只不过整体来看卷积层所占比重最大。
常用计算FLOPs的工具
1 stat打印
import torch
import torch.nn as nn
from torchstat import stat
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(