1.stat(cpu统计)
pip install torchstat
from torchstat import stat
stat(model, (3, 32, 32)) #统计模型的参数量和FLOPs,(3,32,32)是输入图像的size
结果:
问题:当网络中有自定义参数时,就很有可能漏掉那部分参数对应的统计量;stat好像不支持双输入。
2.summary网络结构对应参数(cuda上面统计)
pip install torchsummary
from torchsummary import summary
summary(model,input_size=(3,32,32))
问题:当网络中有自定义参数时,就很有可能漏掉那部分参数。
结果:
3.统计flops和参数量
pip install thop
from thop import profile
dummy_input = torch.randn(1, 3, 32, 32)