Study notes
Three tools for counting model parameters and operations
summary (torchsummary), profile (thop), and stat (torchstat).
torchsummary.summary
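from torchsummary import summary
model = YourModel().cuda()  # the model must be on the device passed as device=
summary(model, input_size=(3, 84, 84), device='cuda')  # input_size excludes the batch dimension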
With multiple inputs:
from torchsummary import summary
model = YourModel().cuda()
# one shape tuple per forward() argument, batch dimension excluded
summary(model, input_size=[(3, 84, 84), (1, 3, 84, 84)], device='cuda')
However, this raises an error:
TypeError: can't multiply sequence by non-int of type 'tuple'
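Fix: line 100 of torchsummary.py (the exact line may differ between versions) reads
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
np.prod cannot multiply out a list of shape tuples with different lengths, so count the elements of each input separately and add them up:
total_input_size = abs(sum(np.prod(s) for s in input_size) * batch_size * 4. / (1024 ** 2.))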
References:
https://blog.csdn.net/kakangel/article/details/130795893
https://github.com/sksq96/pytorch-summary/issues/90
https://blog.csdn.net/qq_43733107/article/details/126508616
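To show why the per-input product avoids the error, here is a minimal standalone sketch of the same input-size estimate in plain Python. It assumes 4 bytes per element (float32), as torchsummary does, and uses math.prod instead of NumPy. The function name total_input_size_mb is hypothetical and is not part of torchsummary.

```python
import math

def total_input_size_mb(input_sizes, batch_size=1, bytes_per_elem=4):
    """Estimate the memory (MB) taken by the model inputs.

    input_sizes: one shape tuple or a list of shape tuples (batch
    dimension excluded), mirroring torchsummary's input_size argument.
    Hypothetical helper, assuming 4 bytes per element (float32).
    """
    # a single input is normalized to a one-element list, as torchsummary does
    if isinstance(input_sizes, tuple):
        input_sizes = [input_sizes]
    # multiply out each shape on its own, then add the per-input counts;
    # this works even when the shapes have different lengths
    n_elems = sum(math.prod(shape) for shape in input_sizes)
    return abs(n_elems * batch_size * bytes_per_elem / (1024 ** 2))

sizes = [(3, 84, 84), (1, 3, 84, 84)]  # shapes of different lengths
print(total_input_size_mb(sizes, batch_size=2))
print(total_input_size_mb((3, 84, 84)))  # single input
```

Each 84×84 RGB input has 3·84·84 = 21168 elements, so the two-input case simply counts 2·21168 elements per sample.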
However, torchsummary does not compute FLOPs.
torchstat.stat
from torchstat import stat
model = YourModel()
stat(model, input_size=(3, 84, 84))
However, stat does not support models with multiple inputs.
thop.profile
Single input
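import torch
from thop import profile
model = YourModel().cuda()  # the model and the input must be on the same device
input = torch.randn(1, 3, 300, 300).cuda()
macs, params = profile(model, inputs=(input, ))  # the trailing comma is required: inputs must be a tuple
# thop counts multiply-accumulate operations (MACs); FLOPs are roughly 2 x MACs
print('MACs:', "%.2fM" % (macs / 1e6), 'Params:', "%.2fM" % (params / 1e6))
total = sum(param.nelement() for param in model.parameters())
print('Number of parameters: %.2fM' % (total / 1e6))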
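The trailing comma in inputs=(input, ) matters because thop's profile forwards inputs to the model roughly as model(*inputs). The sketch below imitates that with plain Python lists. profile_like and count_rows are hypothetical stand-ins, not thop functions, and the nested list only stands in for a tensor of shape (1, 3).

```python
def profile_like(model, inputs):
    # hypothetical stand-in: forwards the tuple as positional arguments
    return model(*inputs)

def count_rows(x):
    # hypothetical "model": reports how many rows (batch entries) it received
    return len(x)

batch = [[1.0, 2.0, 3.0]]  # stands in for a tensor of shape (1, 3)

print(type((batch,)).__name__)  # tuple: the trailing comma makes a tuple
print(type((batch)).__name__)   # list: parentheses alone do not
print(profile_like(count_rows, (batch,)))  # 1: the whole batch is one argument
print(profile_like(count_rows, (batch)))   # 3: *batch strips the outer list, losing the batch dimension
```

With a real tensor the same unpacking slices off the batch dimension, so the forward pass either fails or receives the wrong shape.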
Multiple inputs
import torch
from thop import profile
model = YourModel().cuda()
input1 = torch.randn(1, 3, 84, 84).cuda()
input2 = torch.randn(1, 1, 3, 84, 84).cuda()
macs, params = profile(model, inputs=(input1, input2))  # one tensor per forward() argument
print('MACs:', "%.2fM" % (macs / 1e6), 'Params:', "%.2fM" % (params / 1e6))
total = sum(param.nelement() for param in model.parameters())
print('Number of parameters: %.2fM' % (total / 1e6))
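To sanity-check the numbers these tools print, you can count a single convolution by hand. The sketch below uses the standard textbook formulas for a Conv2d layer with groups=1. The layer Conv2d(3, 16, kernel_size=3, padding=1) on a 3×84×84 input is a hypothetical example. The MAC count ignores bias additions, so thop's exact total may differ slightly depending on its version.

```python
def conv_out(size, k, stride=1, padding=0):
    # spatial output size of a convolution (dilation 1)
    return (size + 2 * padding - k) // stride + 1

def conv2d_params(c_in, c_out, k, bias=True):
    # c_out filters of shape c_in x k x k, plus one bias per filter
    return c_out * c_in * k * k + (c_out if bias else 0)

def conv2d_macs(c_in, c_out, k, h_out, w_out):
    # each output element needs c_in * k * k multiply-accumulates (bias ignored)
    return c_in * k * k * c_out * h_out * w_out

# hypothetical layer: Conv2d(3, 16, kernel_size=3, padding=1) on a 3x84x84 input
h = w = conv_out(84, k=3, padding=1)
params = conv2d_params(3, 16, 3)
macs = conv2d_macs(3, 16, 3, h, w)
print(h, w)    # 84 84
print(params)  # 448
print(macs)    # 3048192
print('approx. FLOPs: %.2fM' % (2 * macs / 1e6))
```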
Reference:
https://www.cnblogs.com/ycycn/p/17928507.html#4pytorch_71