这两个函数可以用来计算flops、madd、内存、参数、形状等
例如对于一个网络net1
使用summary与stat函数对其性能进行检查
其中(1,28,28)是网络输入形状(28*28的图片)
from torchstat import stat
from torchsummary import summary
import torch
def main():
model = Net_2()
stat(model, (1, 28, 28))
if __name__ == '__main__':
main()
便可以对各种要素进行展示
注:
是因为展平层和dropout层没有形状参数这些
from torchstat import stat
from torchsummary import summary
import torch
def main():
model = Net_2()
summary(model.cuda(), (1, 28, 28), )
if __name__ == '__main__':
main()
summary函数只对形状以及参数个数进行展示。