# 安装 / Install：pip install ptflops
# 单独使用（统计单个模型）/ Standalone usage (single model)：
import torch
from ptflops import get_model_complexity_info
# ``model`` must already exist as a torch.nn.Module; it is not defined in
# this snippet.
# ptflops expects ``input_res`` as a *tuple* that excludes the batch
# dimension (ptflops prepends a batch of 1 itself). It asserts
# ``type(input_res) is tuple``, so a list such as [1, 32, 32] is rejected.
flops, params = get_model_complexity_info(model, (1, 32, 32), as_strings=True,
                                          print_per_layer_stat=True)
# NOTE: per the ptflops docs, the first value counts multiply-accumulate
# operations (MACs), so "flops" here is really a MAC count.
print(flops, params)
# 批量处理（统计多个模型并汇总）/ Batch processing (several models, summarized)：
import torch, os
from ptflops import get_model_complexity_info
class Cal_Params():
    """Measure and record the complexity of one model stored under ``models/<name>``.

    NOTE(review): relies on a project helper ``get_model(path)`` that is not
    shown here. It is assumed to return a torch ``nn.Module`` that exposes an
    integer ``size`` attribute used as the square input resolution — confirm.
    """

    def __init__(self, model_name, device='cuda'):
        """Load the model named ``model_name`` and derive its input resolution.

        Args:
            model_name: Name of the model directory under ``models/``.
            device: Torch device string the model is moved to (default ``'cuda'``).
        """
        self.model_name = model_name
        self.path = r'models/{}'.format(model_name)
        self.model = get_model(self.path).to(torch.device(device))
        # ptflops input resolution, without the batch dimension (ptflops adds
        # a batch of 1 itself). Presumably (channels, height, width), i.e. a
        # single-channel square input of side ``model.size`` — confirm.
        self.input_size = (1, self.model.size, self.model.size)

    def get_params(self, save_file, verbose=True):
        """Compute complexity, write the per-layer report, and add a summary row.

        The per-layer breakdown from ptflops is written to
        ``models/<name>/params.txt``. One formatted summary row is written to
        ``save_file`` via ``display``.

        Args:
            save_file: Open, writable text stream that receives the summary row.
            verbose: When true, the summary row is also printed to stdout.

        Returns:
            The ``(flops, params)`` pair returned by ptflops, formatted as
            strings because ``as_strings=True``.
        """
        filepath = os.path.join(self.path, 'params.txt')
        # The context manager guarantees the report file is flushed and
        # closed, even if the complexity computation raises.
        with open(filepath, 'w') as f:
            flops, params = get_model_complexity_info(
                self.model, self.input_size, as_strings=True,
                print_per_layer_stat=True, ost=f)
        display('%9s | %11s | %9s' % (self.model_name, flops, params),
                file=save_file, verbose=verbose)
        return flops, params
def display(string, file=None, verbose=True):
    """Print ``string`` to an optional stream and, optionally, to stdout.

    Args:
        string: Text to emit; ``print`` appends the trailing newline.
        file: Optional writable text stream. When not ``None``, ``string`` is
            printed to it.
        verbose: When true, ``string`` is also printed to stdout.
    """
    if file is not None:
        print(string, file=file)
    if verbose:
        print(string)
if __name__ == '__main__':
    import sys  # only the script entry point needs it

    model_names = ['model1', 'model2', 'model3']
    # ``with`` closes the summary file on every exit path, replacing the
    # manual try/finally.
    with open('all_params.txt', 'w') as save_file:
        # Header row; the column widths match the per-model summary rows.
        display('%9s | %11s | %9s' % ('Model', 'FLOPs', 'Params'), file=save_file)
        for model_name in model_names:
            # Isolate failures per model so one broken model does not stop
            # the remaining models from being measured.
            try:
                cp = Cal_Params(model_name)
                cp.get_params(save_file)
            except Exception as e:
                print('Error ({}): {}'.format(model_name, e), file=sys.stderr)