在PyTorch中,可以使用torchstat这个库来查看网络模型的一些信息,包括总的参数量params、MAdd、显卡内存占用量和FLOPs等。
使用前需要先安装torchstat包,如下:
pip install torchstat
示例代码如下:
from torchstat import stat
from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d
model = resnet50()
stat(model, (3, 224, 224))
如果只是想看模型的总参数量,可以通过如下方式:
total = sum([param.nelement() for param in model.parameters()])
print("Number of parameters: %.2fM" % (total/1e6))
stat打印完整信息如下: