Pytorch-OpCounter: Pytorch平台计算模型#Parameters和FLOPS的工具包
OpCounter (Github地址:https://github.com/Lyken17/pytorch-OpCounter)除了能够统计各种模型结构的参数以及FLOPS, 还能为那些特殊的运算定制化统计规则,非常好用。
OpCounter的安装
方式1: pip install thop
方式2: pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git
使用示例
from torchvision.models import resnet50
from thop import profile, clevar_style
model = resnet50()
input = torch.randn(1, 3, 224, 224) # (batch_size, num_channel, Height, Width)
flops, params = profile(model, inputs=(input, ))
print('flops: {}, params: {}'.format(flops, params))
输出结果如下:
flops: 2914598912.0, params: 7978856.0
如果模型中有自定义的特殊运算类:ModuleName,为其定义的运算统计规则为count_model, 如下:
class ModuleName(nn.Module):
# your definition
def count_model(model, x, y):
# your rule here
则调用时,可以通过参数custom_ops
来定制:
flops, params = profile(model, inputs=(input, ),
custom_ops={ModuleName: count_model})
此外,clevar_style
可以对输出结果进行简单处理,以更好的展示:
flops, params = clever_format([flops, params], "%.3f")
print('flops: {}, params: {}'.format(flops, params))
实现原理
该工具为每一种基本操作都定义了参数统计和运算量计算,目前主要包含视觉方面的运算,包括各种卷积、激活函数、池化、批归一化等。例如最常见的二维卷积运算,它的统计代码如下所示:
def count_conv2d(m, x, y):
x = x[0]
cin = m.in_channels
cout = m.out_channels
kh, kw = m.kernel_size
batch_size = x.size()[0]
out_h = y.size(2)
out_w = y.size(3)
# ops per output element
# kernel_mul = kh * kw * cin
# kernel_add = kh * kw * cin - 1
kernel_ops = multiply_adds * kh * kw
bias_ops = 1 if m.bias is not None else 0
ops_per_element = kernel_ops + bias_ops
# total ops
# num_out_elements = y.numel()
output_elements = batch_size * out_w * out_h * cout
total_ops = output_elements * ops_per_element * cin // m.groups
m.total_ops = torch.Tensor([int(total_ops)])