Pytorch 计算模型的FLOPs和参数量

安装:pip install ptflops

单独使用:

import torch
from ptflops import get_model_complexity_info
flops, params = get_model_complexity_info(model, [1,32,32], as_strings=True, print_per_layer_stat=True)
print(flops, params)

批量处理:

import torch, os
from ptflops import get_model_complexity_info


class Cal_Params():
    def __init__(self, model_name, device='cuda'):
        self.model_name = model_name
        self.path = r'models/{}'.format(model_name)
        self.model = get_model(self.path).to(torch.device(device))
        self.input_size = (1, self.model.size, self.model.size)

    def get_params(self, save_file, verbose=True):
        filepath = os.path.join(self.path, 'params.txt')
        f = open(filepath, 'w')
        flops, params = get_model_complexity_info(self.model, self.input_size, as_strings=True,
                                                  print_per_layer_stat=True, ost=f)
        display('%9s | %11s | %9s' % (self.model_name, flops, params), file=save_file, verbose=verbose)


def display(string, file=None, verbose=True):
    if file != None:
        print(string, file=file)
    if verbose:
        print(string)
        # devnull = open(os.devnull, 'w')
        # print(string, file=devnull)
        # devnull.close()


if __name__ == '__main__':
    save_file = open('all_params.txt', 'w')
    model_names = ['model1', 'model2', 'model3']
    losses = ['L1', 'L2']
    display('%9s | %11s | %9s' % ('Model', 'FLOPs', 'Params'), file=save_file)
    try:
        for model_name in model_names:
            cp = Cal_Params(model_name)
            cp.get_params(save_file)
    except Exception as e:
        print('Error: {}'.format(e))
    finally:
        save_file.close()

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值