torch_flops: 准确捕获forward中所有算子的FLOPs计算库

目录

为什么需要torch_flops库

torch_flops库介绍

torch_flops的使用和对比

尚存的限制

最后


为什么需要torch_flops库

在对比神经网络模型轻量程度时,通常会使用FLOPs(Floating Point Operations. 参考链接)指标(但要注意的是FLOPs小并不代表模型速度快)。目前已有很多FLOPs计算库,如thopptflopstorchstat,但尝试过六七个库之后,发现它们几乎都只能支持nn.Module子类的计算,比如nn.Linear,而不会计算forward()中直接写出的运算符,比如torch.softmax()tensor.exp()+@等。这就会引发FLOPs计算偏低或偏高的问题:

  1. FLOPs计算可能偏低:如果自己实现一个Attention计算,其中forward会包含exp、softmax、matmul等运算(参考timm中的attention实现),那么exp、softmax、matmul等非nn.Module运算符的FLOPs都会被忽略,导致计算的FLOPs比实际FLOPs偏低。

  2. FLOPs计算可能偏高:如果一个nn.Linear在模型的__init__()中定义了,但forward()中没有使用,还是会计算此Linear的FLOPs,比实际FLOPs偏高。

为了解决不能准确计算forward()过程实际FLOPs的问题,笔者基于pytorch的Symbolic Tracing写了一套FLOPs计算代码库:torch_flops。经过测试验证,torch_flops能够实现计算forward中所有运算符的FLOPs,从而解决了以上两个问题。项目开源地址:https://github.com/zugexiaodui/torch_flops。

torch_flops库介绍

  • 计算原理

FLOPs计算库的整体思路都是首先捕获运算算子,然后根据运算算子的类别、参数、输入特征图维度信息,由算子的FLOPs计算公式给出该FLOPs值,最后求和即可。这里最关键的是如何准确捕获forward中的所有算子,包括继承自nn.Module的子类(这一种是最简单的,也是大部分库都实现了的)、pytorch的函数运算(如 torch.softmax()、+),以及tensor本身的方法(如tensor.exp()、tensor.add())。

在pytorch 1.8之后,新增了torch.fx模块,其中包含Symbolic Tracing功能,该功能能够把模型描述为一个图(GraphModule),图中的每个节点都代表一个算子,并且可以获取这些算子的类别、参数、输入张量等信息,甚至修改算子的实际运算过程(可以参考torch.fx文档)。因此,torch_flops的整体思路就是首先把模型转换为GraphModule,然后遍历每个节点,同时调用已定义好的该节点的FLOPs计算代码,最终就可以获取所有节点的FLOPs了。

  • 安装torch_flops

torch_flops已上传至pypi仓库,可通过pip install torch_flops直接下载安装。或访问torch_flops仓库下载代码,通过python setup.py install安装。

  • 依赖库

torch>=1.8:所使用的pytorch版本需要支持torch.fx,因此需要torch>= 1.8.0。代码只在 torch>=2.0.0上测试过,建议最好使用2.0以上版本。

tabulate:需tabulate库以支持表格打印功能,版本无特殊要求,最新即可。

torch_flops的使用和对比

此处给出torch_flops的使用代码及与其他3个star数量较多的FLOPs计算库thop、ptflops和torchanalyse的对比结果,展示如何使用torch_flops库的同时,能够体现出torch_flops准确捕获forward所有算子的优势。 此处只给出主要代码,完整代码可见torch_flops/blob/main/compare.py

首先,定义一个简单模型,只包含一个线性层nn.Linear,默认带有bias项(args.linear_no_bias=False)。在forward()中,除了让输入的x进入线性层,还可选择地运行x+=1,作为forward()中的非Module运算。是否运行x+=1由命令行解析的传入参数--add_one 是否为True决定。

class SimpleModel(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layer = nn.Linear(5, 4, bias=not args.linear_no_bias)
        self.__add_one = args.add_one

    def forward(self, x: Tensor):
        x = self.layer(x)
        if self.__add_one:
            x += 1.
        return x
  • 使用torch_flops计算FLOPs的示例

    model = SimpleModel(inp_args)
    x = torch.randn(1, 5)
    y = model(x)
    print("*" * 40 + " Model " + "*" * 40)
    print(model)
    print(y)
    print("=" * 80)

    # =========
    print("*" * 40 + " torch_flops " + "*" * 40)
    flops_counter = TorchFLOPsByFX(model)
    # flops_counter.graph_model.graph.print_tabular()
    flops_counter.propagate(x)
    flops_counter.print_result_table()
    flops_1 = flops_counter.print_total_flops(show=False)
    print(f"torch_flops: {flops_1} FLOPs")
    print("=" * 80)

其中,关键的语句只有3行:

  1. flops_counter = TorchFLOPsByFX(model):把模型传入torch_flops定义的计算FLOPs的类TorchFLOPsByFX。这个过程会调用pytorch的symbolic_trace()方法,也是最可能出错的步骤。如果出错,则需要规范forward()代码 ,使其不存在动态控制流

  2. flops_counter.propagate(x):把输入送进网络前向传播,同时就会计算所有算子节点的FLOPs。

  3. flops_counter.print_result_table():以表格形式打印所有运算节点的类型和FLOPs。如果某些节点还没有FLOPs计算函数,则会进行提示,并标记not recognized。 可以自行修改源码补充对应的FLOPs计算函数或在仓库提交issue。

对于以上定义的模型和代码,在__add_one为False时输出结果如下:

**************************************** Model ****************************************
SimpleModel(
  (layer): Linear(in_features=5, out_features=4, bias=True)
)
tensor([[-0.2077,  0.2623,  1.3978, -0.4170]], grad_fn=<AddmmBackward0>)
================================================================================
**************************************** torch_flops ****************************************
===========  ===========  ===========  =====================  =======
node_name    node_op      op_target    nn_module_stack[-1]      flops
===========  ===========  ===========  =====================  =======
x            placeholder  x                                         0
layer        call_module  layer        Linear                      40
output       output       output                                    0
===========  ===========  ===========  =====================  =======
torch_flops: 40 FLOPs

当加入x+=1操作,即__add_one为True时,可以看到所计算的FLOPs增加了对应的加法运算量:

**************************************** Model ****************************************
SimpleModel(
  (layer): Linear(in_features=5, out_features=4, bias=True)
)
tensor([[1.0426, 0.6963, 1.7114, 1.6526]], grad_fn=<AddBackward0>)
================================================================================
**************************************** torch_flops ****************************************
===========  =============  =======================  =====================  =======
node_name    node_op        op_target                nn_module_stack[-1]      flops
===========  =============  =======================  =====================  =======
x            placeholder    x                                                     0
layer        call_module    layer                    Linear                      40
add          call_function  <built-in function add>                               4
output       output         output                                                0
===========  =============  =======================  =====================  =======
torch_flops: 44 FLOPs

从而验证了torch_flops能够准确捕获并计算forward()所有算子FLOPs的功能。

  • 与其他FLOPs计算库的对比

而对于torchanalyse、thop和ptflops库,当__add_one为False即不执行x+=1时计算的FLOPs (或MACs,乘加数量)结果如下:

**************************************** torchanalyse ****************************************
torchanalyse: 40 FLOPs
================================================================================
**************************************** thop ****************************************
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
thop: 20 MACs
================================================================================
**************************************** ptflops ****************************************
Warning: module SimpleModel is treated as a zero-op.
SimpleModel(
  24, 100.000% Params, 24.0 Mac, 100.000% MACs, 
  (layer): Linear(24, 100.000% Params, 24.0 Mac, 100.000% MACs, in_features=5, out_features=4, bias=True)
)
ptflops: 24 MACs
================================================================================

当改变__add_one为True即在forward()中增加x+=1运算时, FLOPs计算结果并没有改变:

**************************************** torchanalyse ****************************************
torchanalyse: 40 FLOPs
================================================================================
**************************************** thop ****************************************
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
thop: 20 MACs
================================================================================
**************************************** ptflops ****************************************
Warning: module SimpleModel is treated as a zero-op.
SimpleModel(
  24, 100.000% Params, 24.0 Mac, 100.000% MACs, 
  (layer): Linear(24, 100.000% Params, 24.0 Mac, 100.000% MACs, in_features=5, out_features=4, bias=True)
)
ptflops: 24 MACs
================================================================================

笔者曾试过六七个FLOPs库,对于这种在forward()中增加运算的情况几乎都无能为力,这也是笔者不得不自己开发一个新库的原因。

尚存的限制

pytorch_flops能够准确捕获forward()中所有算子并计算FLOPs是该库最大的特色和优势,但同时该库的使用也存在一些限制。

  1. 需要较高版本的pytorch。torch.fx功能在1.8.0才出现,笔者代码是在2.0版本上编写测试的,因此最好使用高于2.0版本的pytorch,这可能会导致在某些复杂的运行环境下与其他一些库不兼容。但可以增加一个独立的pytorch 2.0环境专门用于安装torch_flops。另外,笔者曾两次发现对于一模一样的代码,pytorch 2.0可能会比pytorch 1.8占用更少的显存,尚未研究发生的条件和原因。而且pytorch 2.0增加了很多实用功能(如torch.compile),新版本才是未来。
  2. 模型定义不能使用动态控制流,模型图在定义完成后需要静态确定。即forward()中不能存在由if 实时决定的运算,注意上述代码中if self.__add_one不算,因为__add_one在模型初始化后已经有确定数值了,整个graph也就确定了;除此之外还不能存在for循环等,可以参考limitations-of-symbolic-tracing了解更多。
  3. 可计算FLOPs的算子尚未支持太多。当前(0.1.1版本)支持的算子可见torch_flops/flops_ops.py文件中末尾的MODULE_FLOPs_MAPPING、FUNCTION_FLOPs_MAPPING和METHOD_FLOPs_MAPPING。目前框架已完成,补充新算子成为需要花费精力的增量工作。如果需要某些特定算子的FLOPs计算,可以自行下载源码在flops_ops.py中增加,或在仓库中提出issue。

最后

欢迎对torch_flops多多star!欢迎更多朋友提出宝贵建议、一起开发完善torch_flops,规范FLOPs计算过程,增强对比公平性。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值