1. 安装 ptflops
pip install ptflops
2. 计算参数量（params）和计算量（FLOPs；ptflops 实际统计的是 MACs）
import torch
from ptflops import get_model_complexity_info
from torch import nn
def print_time_paramter_complexity(net, input_size):
    """Print the computational complexity and parameter count of ``net``.

    Both figures come from ptflops' ``get_model_complexity_info`` as
    human-readable strings; per-layer statistics are printed as well.

    Args:
        net: The model to analyse.
        input_size: Input resolution, forwarded unchanged to
            ``get_model_complexity_info``.
    """
    complexity, n_params = get_model_complexity_info(
        net,
        input_size,
        as_strings=True,
        print_per_layer_stat=True,
        verbose=True,
    )
    row = '{:<30} {:<8}'
    summary = (
        ('Computational complexity: ', complexity),
        ('Number of parameters: ', n_params),
    )
    for label, value in summary:
        print(row.format(label, value))
3. 多输入模型的 params 和 FLOPs（通过 input_constructor 自定义输入构造）
class Model(nn.Module):
    """Toy parameter-free model taking two tensors packed in one argument.

    Used to demonstrate how ptflops can profile a model whose ``forward``
    receives several inputs.
    """

    def __init__(self):
        # ``super()`` must be *called*; ``super.__init__()`` raises TypeError
        # and would leave nn.Module's internal registries uninitialised.
        super().__init__()

    def forward(self, xs):
        """Return the element-wise sum of the first two tensors in ``xs``.

        Args:
            xs: Indexable collection whose items 0 and 1 are tensors with
                broadcast-compatible shapes; any further items are ignored.

        Returns:
            ``xs[0] + xs[1]``.
        """
        x1, x2 = xs[0], xs[1]
        return x1 + x2
def prepare_input(resolution, input_size=None):
    """Build the keyword-argument batch that ptflops feeds to the model.

    Per the ptflops documentation, ``input_constructor`` is called with the
    ``input_res`` value only, and a returned dict is unpacked as
    ``model(**batch)``.  Hence ``resolution`` alone must be sufficient and
    ``input_size`` is optional.

    Args:
        resolution: Sequence of full input shapes, batch dimension included.
            For 3-D data, e.g. ``((2, 3, 80, 192, 160), (2, 1, 80, 192, 160))``.
        input_size: Optional explicit sequence of shapes that overrides
            ``resolution`` (kept so existing two-argument callers still work).

    Returns:
        ``{'xs': (t0, t1, ...)}`` with one float32 tensor of ones per shape.
        The key must equal the parameter name of ``Model.forward`` (``xs``),
        otherwise ``model(**batch)`` fails with an unexpected-keyword error.
    """
    shapes = resolution if input_size is None else input_size
    # torch.FloatTensor(some_tuple) would create a 1-D tensor *containing* the
    # shape numbers; torch.ones(*shape) allocates a tensor *of* that shape.
    # Ones (rather than uninitialised memory) keep the forward pass deterministic.
    tensors = tuple(torch.ones(*shape, dtype=torch.float32) for shape in shapes)
    return dict(xs=tensors)
def print_time_paramter_complexity(net, input_size):
    """Print computational complexity and parameter count of a multi-input net.

    Inputs are synthesised by ``prepare_input`` (passed to ptflops as
    ``input_constructor``), so ``input_size`` holds one full shape per input,
    batch dimension included, e.g. ``((2, 3, 80, 192, 160), (2, 1, 80, 192, 160))``.

    Args:
        net: The model to analyse.
        input_size: Sequence of input shapes.  Lists are accepted and
            converted to a tuple, because ptflops' ``get_model_complexity_info``
            asserts that its ``input_res`` argument is a tuple.
    """
    input_res = tuple(input_size)
    macs, params = get_model_complexity_info(
        net,
        input_res,
        as_strings=True,
        print_per_layer_stat=True,
        verbose=True,
        input_constructor=prepare_input,
    )
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))
if __name__ == "__main__":
    net = Model()
    # ptflops asserts that ``input_res`` is a tuple, so the per-input shapes
    # (batch dimension included) are given as a tuple of tuples, not a list.
    print_time_paramter_complexity(
        net,
        input_size=((2, 3, 80, 192, 160), (2, 1, 80, 192, 160)),
    )