使用场景:
(pytorch)获取网络的参数,使用torchstat中的stat函数
问题描述
在使用stat进行参数打印时,我们通常会得到一大串数据,可是有一些的参数不甚了解,在经过查询后,得到了一些相关解释,故特此记录。
下载方式
pip install torchstat
常用用法
from torchstat import stat
from network import net #此处net为我们使用的网络
model = net()
stat(model,(3,256,256)) #格式为 stat(网络名称,(波段数,图像大小))
常见输出
【Total params】网络的整体参数量
【Total memory】模型进行推理时候所需的内存
【Total Flops】网络完成的浮点运算
【Total MAdd】网络完成的乘加操作的数量。一次乘加=一次乘法+一次加法,可以大致认为(Flops ≈2*MAdd)
【MemR+W】MemR+W = MemRead + MemWrite
(MemRead:网络运行时,从内存中读取的大小)
(MemWrite:网络运行时,写入到内存中的大小)
适用范围(支持的图层):
FLOPS和FLOPs的区别:
FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。