模型的计算复杂度可以通过:参数量和FLOPs(浮点运算的总量)
def print_info(model,input):
"""
打印模型的信息
model=nn.Sequential(
nn.Conv2d(3,64,3,1,1),
nn.Conv2d(64,64,3,1,1),
nn.Conv2d(64,3,3,1,1)
)
x=torch.rand(1,3,96,96)
print_info(model,x)
"""
from torchstat import stat
from torchinfo import summary
if torch.is_tensor(input):
if input.dim()>3:
print('summary model info')
summary(model,input.shape)
input=input.squeeze(0)
print('stat model info ')
stat(model,input.shape)
else:
print('输入信息错误')