这里介绍两种方法
1、使用相关的库torchsummary
from torchsummary import summary
net=net.to(torch.device("cpu"))#or cuda
summary(net,(4,228,912),device="cpu") #or cuda
统计结果比较详细,参数量、浮点数计算量、中间变量、train的变量数、保持不变的变量数,每一层的中间变量和类型都会详细列出
2、使用库thop
from thop import profile
net=net.cuda()
input= torch.ones([1,4,128,128]).cuda()
inputs=[]
inputs.append(input)
flops, params=profile(net,inputs)#,custom_ops={model.ResNet,countModel})
print("flops:{0:,} ".format(flops))
print("parms:{0:,}".format(params))
这个比较简单,最后简单的输出参数量和其中的浮点计算次数