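FLOPS (floating-point operations per second): the number of floating-point operations performed per second. It measures computation speed and is a metric of hardware performance.
FLOPs (floating-point operations): the lowercase "s" marks the plural. It measures the amount of computation and is used to gauge the complexity of an algorithm or model.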
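FLOPs is an amount of work and FLOPS is a rate, so dividing the first by the second gives an idealized lower bound on runtime. The sketch below assumes the hardware sustains its peak rate for the whole computation, which real workloads rarely do. The function name `ideal_runtime_seconds` and the numbers are illustrative, not measurements. Watch the counting convention as well: vendor peak-FLOPS figures usually count a fused multiply-add as two operations, while the counter documented below counts it as one.

```python
def ideal_runtime_seconds(model_flops: float, hardware_flops_per_second: float) -> float:
    """Idealized lower bound on runtime: work (FLOPs) / throughput (FLOPS).

    Assumes the device runs at its peak rate the whole time, so measured
    runtimes will normally be longer.
    """
    if hardware_flops_per_second <= 0:
        raise ValueError("hardware_flops_per_second must be positive")
    return model_flops / hardware_flops_per_second


# Hypothetical example: a 4 GFLOPs forward pass on a device rated at 10 TFLOPS.
print(ideal_runtime_seconds(4e9, 10e12))  # 0.0004 (seconds)
```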
"""
Provides access to per-submodule model flop counts obtained by
tracing a model with PyTorch's JIT tracing functionality. By default,
it comes with standard flop counters for a few common operators.
Note that:
1. Flop is not a well-defined concept. We just produce our best estimate.
2. We count one fused multiply-add as one flop.
Handles for additional operators may be added, or the default ones
overwritten, using the ``.set_op_handle(name, func)`` method.
See the method documentation for details.
Flop counts can be obtained as:
* ``.total(module_name="")``: total flop count for the module
* ``.by_operator(module_name="")``: flop counts for the module, as a Counter
over different operator types
* ``.by_module()``: Counter of flop counts for all submodules
* ``.by_module_and_operator()``: dictionary, keyed by submodule name, of Counters
over different operator types
An operator is treated as within a module if it is executed inside the
module's ``__call__`` method. Note that this does not include calls to
other methods of the module or explicit calls to ``module.forward(...)``.
Example usage:
>>> import torch.nn as nn
>>> import torch
>>> from fvcore.nn import FlopCountAnalysis
>>> class TestModel(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.fc = nn.Linear(in_features=1000, out_features=10)
...         self.conv = nn.Conv2d(
...             in_channels=3, out_channels=10, kernel_size=1
...         )
...         self.act = nn.ReLU()
...     def forward(self, x):
...         return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1, 3, 10, 10)),)
>>> flops = FlopCountAnalysis(model, inputs)
>>> flops.total()
13000
>>> flops.total("fc")
10000
>>> flops.by_operator()
Counter({'addmm': 10000, 'conv': 3000})
>>> flops.by_module()
Counter({'': 13000, 'fc': 10000, 'conv': 3000, 'act': 0})
>>> flops.by_module_and_operator()
{'': Counter({'addmm': 10000, 'conv': 3000}),
 'fc': Counter({'addmm': 10000}),
 'conv': Counter({'conv': 3000}),
 'act': Counter()
}
"""
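The numbers in the example above can be checked by hand. The docstring counts one fused multiply-add as one flop. Under that convention, a convolution costs one multiply-add per weight per output position, and a linear layer (the `addmm` entry) costs `in_features * out_features` per sample. The helpers below are a minimal sketch written for this check, not the counter's own handlers. They assume bias additions are not counted, which matches the totals in the example.

```python
def conv2d_flops(c_in, c_out, kernel_h, kernel_w, out_h, out_w, batch=1, groups=1):
    """Multiply-adds of a 2D convolution, counting one fused multiply-add as one flop."""
    # Each output value is a dot product over (c_in / groups) * kernel_h * kernel_w inputs.
    macs_per_output = (c_in // groups) * kernel_h * kernel_w
    return batch * c_out * out_h * out_w * macs_per_output


def linear_flops(in_features, out_features, batch=1):
    """Multiply-adds of a (batch x in) @ (in x out) matrix product; bias is ignored."""
    return batch * in_features * out_features


# TestModel: a 1x1 conv keeps the 10x10 spatial size, then 10 * 10 * 10 = 1000
# features are flattened into the fc layer.
conv = conv2d_flops(c_in=3, c_out=10, kernel_h=1, kernel_w=1, out_h=10, out_w=10)
fc = linear_flops(in_features=10 * 10 * 10, out_features=10)
print(conv, fc, conv + fc)  # 3000 10000 13000

# Counting multiplies and adds separately would roughly double the total.
print(2 * (conv + fc))  # 26000
```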
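The four accessors listed in the docstring are different views of the same per-operator counts. Each operator is charged to the module whose `__call__` it runs in and to every enclosing module. That is why `by_module()` reports 13000 for the root `""` and 10000 for `"fc"`. The class below is a hypothetical toy: `ToyFlopReport` and its `(module path, operator, flops)` record format are invented for illustration and do not describe the real class's internals. It rebuilds those views with `collections.Counter`.

```python
from collections import Counter
from typing import Dict, Iterable, List, Sequence, Tuple

# (dotted module path, operator name, flop count); "" is the root module.
Record = Tuple[str, str, int]


def enclosing_modules(path: str) -> List[str]:
    """Return the module and all its ancestors, e.g. 'a.b' -> ['', 'a', 'a.b']."""
    if not path:
        return [""]
    parts = path.split(".")
    return [""] + [".".join(parts[: i + 1]) for i in range(len(parts))]


class ToyFlopReport:
    """Hypothetical aggregation of per-operator flop records, for illustration only."""

    def __init__(self, records: Iterable[Record], module_names: Sequence[str] = ("",)):
        # Pre-register known modules so ones without counted operators show up with 0.
        self._table: Dict[str, Counter] = {name: Counter() for name in module_names}
        for path, op, flops in records:
            # Charge the operator to its own module and to every enclosing module.
            for module in enclosing_modules(path):
                self._table.setdefault(module, Counter())[op] += flops

    def by_module_and_operator(self) -> Dict[str, Counter]:
        return {name: Counter(ops) for name, ops in self._table.items()}

    def by_operator(self, module_name: str = "") -> Counter:
        return Counter(self._table.get(module_name, Counter()))

    def total(self, module_name: str = "") -> int:
        return sum(self.by_operator(module_name).values())

    def by_module(self) -> Counter:
        return Counter({name: sum(ops.values()) for name, ops in self._table.items()})


# Records matching the TestModel example: ReLU contributes no counted operator.
report = ToyFlopReport(
    [("conv", "conv", 3000), ("fc", "addmm", 10000)],
    module_names=("", "fc", "conv", "act"),
)
print(report.total())      # 13000
print(report.total("fc"))  # 10000
print(report.by_module())  # Counter({'': 13000, 'fc': 10000, 'conv': 3000, 'act': 0})
```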
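The docstring says only work done inside a module's `__call__` is attributed to that module. This mirrors how PyTorch itself separates `__call__` from `forward`: `Module.__call__` runs registered hooks around `forward`, while calling `forward()` directly skips them. The snippet below shows that PyTorch behavior with a forward hook. It is an illustrative analogy, not a reproduction of how the flop counter attributes operators internally.

```python
import torch
import torch.nn as nn

seen = []


def record_call(module, inputs, output):
    # Forward hooks are run by Module.__call__ after forward() returns.
    seen.append(type(module).__name__)


layer = nn.Linear(4, 2)
layer.register_forward_hook(record_call)

x = torch.randn(1, 4)
layer(x)          # goes through __call__, so the hook fires
layer.forward(x)  # bypasses __call__, so the hook does not fire
print(seen)       # ['Linear']
```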