MMSeg分析Flops和Params

本文探讨了如何在PyTorch环境中计算深度学习模型的Flops(浮点运算次数)和Params(参数数量)。通过指定输入图片尺寸,可以使用特定命令进行计算。在分析过程中,需要注意配置文件的设置,避免输入参数与默认值冲突。文章还展示了如何针对模型的特定模块分析其计算量和参数量,以理解模型复杂度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Flops计算量,params参数量
在文件中

tools/analysis_tools/get_flops.py

利用以下命令实现

python tools/analysis_tools/get_flops.py configs/xxx/xxx-Net.py

后面可跟参数shape控制输入图片尺寸,例如

python tools/analysis_tools/get_flops.py configs/xxx/xxx-Net.py --shape 512 512
``
如下展示

```python
python tools/analysis_tools/get_flops.py configs/danet/danet_r50-d8_4xb4-40k_voc12aug-512x512.py
输出
==============================
Compute type: direct: randomly generate a picture
Input shape: (512, 512)
Flops: 0.211T
Params: 47.485M
==============================

坑点1

input_shape" and "inputs" cannot be both set.

在87行左右,由于配置文件在配置了data = model.data_preprocessor(data_batch),所有data中有数据,同时,input_shape通过默认参数得到,两个不能同时有值,所以将data注释掉。希望通过输入的参数计算。
在这里插入图片描述

接下来看一下这个get_model_complexity_info函数都输出的是什么

{'flops': 211043745792, 
'flops_str': '0.211T', 
'activations': 168120320, 
'activations_str': '0.168G', 
'params': 47484961, 
'params_str': '47.485M', 
'out_table': '', 'out_arch': ''}

如何计算某一模块的计算量和参数量呢?
主要看以下代码

    flop_handler = FlopAnalyzer(model, inputs)
    activation_handler = ActivationAnalyzer(model, inputs)

    flops = flop_handler.total()
    activations = activation_handler.total()
    params = parameter_count(model)['']
导入
from mmengine.analysis  import (ActivationAnalyzer, FlopAnalyzer, parameter_count)

这里看一下FlopAnalyzer是如何使用

Examples:
        >>> import torch.nn as nn
        >>> import torch
        >>> class TestModel(nn.Module):
        ...    def __init__(self):
        ...        super().__init__()
        ...        self.fc = nn.Linear(in_features=1000, out_features=10)
        ...        self.conv = nn.Conv2d(
        ...            in_channels=3, out_channels=10, kernel_size=1
        ...        )
        ...        self.act = nn.ReLU()
        ...    def forward(self, x):
        ...        return self.fc(self.act(self.conv(x)).flatten(1))
        >>> model = TestModel()
        >>> inputs = (torch.randn((1,3,10,10)),)
        >>> flops = FlopAnalyzer(model, inputs)
        >>> flops.total()
        13000
        >>> flops.total("fc")
        10000
        >>> flops.by_operator()
        Counter({"addmm" : 10000, "conv" : 3000})
        >>> flops.by_module()
        Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
        >>> flops.by_module_and_operator()
        {"" : Counter({"addmm" : 10000, "conv" : 3000}),
        "fc" : Counter({"addmm" : 10000}),
        "conv" : Counter({"conv" : 3000}),
        "act" : Counter()
        }
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

weightOneMillion

感谢未来的亿万富翁捧个钱场~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值